blob: ac15bc8201dd64b68471f93648c138baf6babfd0 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Python extension-based XLA client."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import itertools
import threading
from absl.testing import absltest
import numpy as np
from tensorflow.compiler.xla.python import custom_call_for_test
from tensorflow.compiler.xla.python import xla_client
class ComputationTest(absltest.TestCase):
"""Base class for running an XLA Computation through the local client."""
def _NewComputation(self, name=None):
if name is None:
name = self.id()
return xla_client.ComputationBuilder(name)
def _Execute(self, c, arguments):
compiled_c = c.Build().Compile()
return xla_client.execute_with_python_values(compiled_c, arguments)
def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
assert expected is not None
result = self._Execute(c, arguments)
# Numpy's comparison methods are a bit too lenient by treating inputs as
# "array-like", meaning that scalar 4 will be happily compared equal to
# [[4]]. We'd like to be more strict so assert shapes as well.
self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape)
assert_func(result, expected)
def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected)
def _ExecuteAndCompareClose(self,
c,
arguments=(),
expected=None,
rtol=1e-7,
atol=0):
self._ExecuteAndAssertWith(
functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), c,
arguments, expected)
def NumpyArrayF32(*args, **kwargs):
"""Convenience wrapper to create Numpy arrays with a np.float32 dtype."""
return np.array(*args, dtype=np.float32, **kwargs)
def NumpyArrayF64(*args, **kwargs):
"""Convenience wrapper to create Numpy arrays with a np.float64 dtype."""
return np.array(*args, dtype=np.float64, **kwargs)
def NumpyArrayS32(*args, **kwargs):
"""Convenience wrapper to create Numpy arrays with a np.int32 dtype."""
return np.array(*args, dtype=np.int32, **kwargs)
def NumpyArrayS64(*args, **kwargs):
"""Convenience wrapper to create Numpy arrays with a np.int64 dtype."""
return np.array(*args, dtype=np.int64, **kwargs)
def NumpyArrayBool(*args, **kwargs):
"""Convenience wrapper to create Numpy arrays with a np.bool dtype."""
return np.array(*args, dtype=np.bool, **kwargs)
class ComputationPrinting(absltest.TestCase):
def ExampleComputation(self):
builder = xla_client.ComputationBuilder("acomputation")
p0 = builder.ParameterFromNumpy(np.float32(0))
p1 = builder.ParameterFromNumpy(np.zeros((4,), np.float32))
builder.Mul(p0, p1)
return builder.Build()
def testComputationToHloText(self):
computation = self.ExampleComputation()
hlo_text = computation.GetHloText()
self.assertTrue(hlo_text.startswith("HloModule acomputation"))
def testComputationToHloGraph(self):
computation = self.ExampleComputation()
hlo_dot_graph = computation.GetHloDotGraph()
self.assertTrue(hlo_dot_graph.startswith("digraph "))
class ComputationsWithConstantsTest(ComputationTest):
"""Tests focusing on Constant ops."""
def testConstantScalarSumS8(self):
c = self._NewComputation()
c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2)))
self._ExecuteAndCompareExact(c, expected=np.int8(3))
def testConstantScalarSumF32(self):
c = self._NewComputation()
c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
self._ExecuteAndCompareClose(c, expected=4.25)
def testConstantScalarSumF64(self):
c = self._NewComputation()
c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14))
self._ExecuteAndCompareClose(c, expected=4.25)
def testConstantScalarSumS32(self):
c = self._NewComputation()
c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2))
self._ExecuteAndCompareClose(c, expected=3)
def testConstantScalarSumS64(self):
c = self._NewComputation()
c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2))
self._ExecuteAndCompareClose(c, expected=3)
def testConstantVectorMulF16(self):
c = self._NewComputation()
c.Mul(
c.Constant(np.array([2.5, 3.3, -1.2, 0.7], np.float16)),
c.Constant(np.array([-1.2, 2, -2, -3], np.float16)))
self._ExecuteAndCompareClose(
c, expected=np.array([-3, 6.6, 2.4, -2.1], np.float16), rtol=2e-3)
def testConstantVectorMulF32(self):
c = self._NewComputation()
c.Mul(
c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])),
c.Constant(NumpyArrayF32([-1.2, 2, -2, -3])))
self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1])
def testConstantVectorMulF64(self):
c = self._NewComputation()
c.Mul(
c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])),
c.Constant(NumpyArrayF64([-1.2, 2, -2, -3])))
self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1])
def testConstantVectorScalarDivF32(self):
c = self._NewComputation()
c.Div(
c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])),
c.ConstantF32Scalar(2.0))
self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4])
def testConstantVectorScalarDivF64(self):
c = self._NewComputation()
c.Div(
c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])),
c.ConstantF64Scalar(2.0))
self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4])
def testConstantVectorScalarPowF32(self):
c = self._NewComputation()
c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.))
self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])
def testConstantVectorScalarPowF64(self):
c = self._NewComputation()
c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.))
self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])
def testIota(self):
c = self._NewComputation()
c.Iota(np.float32, 10)
self._ExecuteAndCompareExact(c, expected=np.arange(10, dtype=np.float32))
def testBroadcastedIota(self):
c = self._NewComputation()
c.BroadcastedIota(np.int64, (2, 3), 1)
expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=np.int64)
self._ExecuteAndCompareExact(c, expected=expected)
def testBooleanAnd(self):
c = self._NewComputation()
c.And(
c.Constant(NumpyArrayBool([True, False, True, False])),
c.Constant(NumpyArrayBool([True, True, False, False])))
self._ExecuteAndCompareExact(c, expected=[True, False, False, False])
def testBooleanOr(self):
c = self._NewComputation()
c.Or(
c.Constant(NumpyArrayBool([True, False, True, False])),
c.Constant(NumpyArrayBool([True, True, False, False])))
self._ExecuteAndCompareExact(c, expected=[True, True, True, False])
def testBooleanXor(self):
c = self._NewComputation()
c.Xor(
c.Constant(NumpyArrayBool([True, False, True, False])),
c.Constant(NumpyArrayBool([True, True, False, False])))
self._ExecuteAndCompareExact(c, expected=[False, True, True, False])
def testSum2DF32(self):
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])),
c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]])))
self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
def testShiftLeft(self):
c = self._NewComputation()
c.ShiftLeft(c.Constant(NumpyArrayS32([3])), c.Constant(NumpyArrayS32([2])))
self._ExecuteAndCompareClose(c, expected=[12])
def testShiftRightArithmetic(self):
c = self._NewComputation()
c.ShiftRightArithmetic(
c.Constant(NumpyArrayS32([-2])), c.Constant(NumpyArrayS32([1])))
self._ExecuteAndCompareClose(c, expected=[-1])
def testShiftRightLogical(self):
c = self._NewComputation()
c.ShiftRightLogical(
c.Constant(NumpyArrayS32([-1])), c.Constant(NumpyArrayS32([1])))
self._ExecuteAndCompareClose(c, expected=[2**31 - 1])
def testSum2DF64(self):
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])),
c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]])))
self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
def testSum2DWith1DBroadcastDim0F32(self):
# sum of a 2D array with a 1D array where the latter is replicated across
# dimension 0 to match the former's shape.
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayF32([10, 20, 30])),
broadcast_dimensions=(0,))
self._ExecuteAndCompareClose(
c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]])
def testSum2DWith1DBroadcastDim0F64(self):
# sum of a 2D array with a 1D array where the latter is replicated across
# dimension 0 to match the former's shape.
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayF64([10, 20, 30])),
broadcast_dimensions=(0,))
self._ExecuteAndCompareClose(
c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]])
def testSum2DWith1DBroadcastDim1F32(self):
# sum of a 2D array with a 1D array where the latter is replicated across
# dimension 1 to match the former's shape.
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayF32([10, 20, 30])),
broadcast_dimensions=(1,))
self._ExecuteAndCompareClose(
c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]])
def testSum2DWith1DBroadcastDim1F64(self):
# sum of a 2D array with a 1D array where the latter is replicated across
# dimension 1 to match the former's shape.
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayF64([10, 20, 30])),
broadcast_dimensions=(1,))
self._ExecuteAndCompareClose(
c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]])
def testConstantAxpyF32(self):
c = self._NewComputation()
c.Add(
c.Mul(
c.ConstantF32Scalar(2),
c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))),
c.Constant(NumpyArrayF32([100, -100, 200, -200])))
self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189])
def testConstantAxpyF64(self):
c = self._NewComputation()
c.Add(
c.Mul(
c.ConstantF64Scalar(2),
c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))),
c.Constant(NumpyArrayF64([100, -100, 200, -200])))
self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189])
def testCustomCall(self):
c = self._NewComputation()
for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
xla_client.register_custom_call_target(name, fn, platform="cpu")
c.CustomCall(
b"test_subtract_f32",
operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)),
shape_with_layout=xla_client.Shape.array_shape(
np.dtype(np.float32), (), ()),
operand_shapes_with_layout=(
xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
))
self._ExecuteAndCompareClose(c, expected=0.75)
class ParametersTest(ComputationTest):
"""Tests focusing on Parameter ops and argument-passing."""
def setUp(self):
self.f32_scalar_2 = NumpyArrayF32(2.0)
self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3])
self.f64_scalar_2 = NumpyArrayF64(2.0)
self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3])
self.s32_scalar_3 = NumpyArrayS32(3)
self.s32_4vector = NumpyArrayS32([10, 15, -2, 7])
self.s64_scalar_3 = NumpyArrayS64(3)
self.s64_4vector = NumpyArrayS64([10, 15, -2, 7])
def testScalarTimesVectorAutonumberF32(self):
c = self._NewComputation()
p0 = c.ParameterFromNumpy(self.f32_scalar_2)
p1 = c.ParameterFromNumpy(self.f32_4vector)
c.Mul(p0, p1)
self._ExecuteAndCompareClose(
c,
arguments=[self.f32_scalar_2, self.f32_4vector],
expected=[-4.6, 6.6, -8.6, 10.6])
def testScalarTimesVectorAutonumberF64(self):
c = self._NewComputation()
p0 = c.ParameterFromNumpy(self.f64_scalar_2)
p1 = c.ParameterFromNumpy(self.f64_4vector)
c.Mul(p0, p1)
self._ExecuteAndCompareClose(
c,
arguments=[self.f64_scalar_2, self.f64_4vector],
expected=[-4.6, 6.6, -8.6, 10.6])
def testScalarTimesVectorS32(self):
c = self._NewComputation()
p0 = c.ParameterFromNumpy(self.s32_scalar_3)
p1 = c.ParameterFromNumpy(self.s32_4vector)
c.Mul(p0, p1)
self._ExecuteAndCompareExact(
c,
arguments=[self.s32_scalar_3, self.s32_4vector],
expected=[30, 45, -6, 21])
def testScalarTimesVectorS64(self):
c = self._NewComputation()
p0 = c.ParameterFromNumpy(self.s64_scalar_3)
p1 = c.ParameterFromNumpy(self.s64_4vector)
c.Mul(p0, p1)
self._ExecuteAndCompareExact(
c,
arguments=[self.s64_scalar_3, self.s64_4vector],
expected=[30, 45, -6, 21])
def testScalarMinusVectorExplicitNumberingF32(self):
# Use explicit numbering and pass parameter_num first. Sub is used since
# it's not commutative and can help catch parameter reversal within the
# computation.
c = self._NewComputation()
p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1)
p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0)
c.Sub(p1, p0)
self._ExecuteAndCompareClose(
c,
arguments=[self.f32_scalar_2, self.f32_4vector],
expected=[-4.3, 1.3, -6.3, 3.3])
def testScalarMinusVectorExplicitNumberingF64(self):
# Use explicit numbering and pass parameter_num first. Sub is used since
# it's not commutative and can help catch parameter reversal within the
# computation.
c = self._NewComputation()
p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1)
p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0)
c.Sub(p1, p0)
self._ExecuteAndCompareClose(
c,
arguments=[self.f64_scalar_2, self.f64_4vector],
expected=[-4.3, 1.3, -6.3, 3.3])
class BufferTest(ComputationTest):
"""Tests focusing on execution with Buffers."""
def _Execute(self, c, arguments):
compiled_c = c.Build().Compile()
arg_buffers = [xla_client.Buffer.from_pyval(arg) for arg in arguments]
result_buffer = compiled_c.Execute(arg_buffers)
return result_buffer.to_py()
def testConstantSum(self):
c = self._NewComputation()
c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
self._ExecuteAndCompareClose(c, expected=4.25)
def testOneParameterSum(self):
c = self._NewComputation()
c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14))
self._ExecuteAndCompareClose(
c, arguments=[NumpyArrayF32(1.11)], expected=4.25)
def testTwoParameterSum(self):
c = self._NewComputation()
c.Add(
c.ParameterFromNumpy(NumpyArrayF32(0.)),
c.ParameterFromNumpy(NumpyArrayF32(0.)))
self._ExecuteAndCompareClose(
c, arguments=[NumpyArrayF32(1.11),
NumpyArrayF32(3.14)], expected=4.25)
def testCannotCallWithDeletedBuffers(self):
c = self._NewComputation()
c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14))
arg = NumpyArrayF32(1.11)
compiled_c = c.Build().Compile()
arg_buffer = xla_client.Buffer.from_pyval(arg)
arg_buffer.delete()
with self.assertRaises(RuntimeError):
compiled_c.Execute([arg_buffer])
def testDestructureTupleEmpty(self):
t = ()
local_buffer = xla_client.Buffer.from_pyval(t)
pieces = local_buffer.destructure()
self.assertFalse(local_buffer.is_deleted())
self.assertEmpty(pieces)
def testDestructureTupleOneArrayElement(self):
t = (np.array([1, 2, 3, 4], dtype=np.int32),)
local_buffer = xla_client.Buffer.from_pyval(t)
pieces = local_buffer.destructure()
self.assertFalse(local_buffer.is_deleted())
self.assertLen(pieces, 1)
array = pieces[0]
got = array.to_py()
want = NumpyArrayS32([1, 2, 3, 4])
np.testing.assert_equal(want, got)
def testDestructureTupleTwoArrayElementDifferentType(self):
t = (
np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
np.array([2, 3, 4, 5], dtype=np.int32),
)
local_buffer = xla_client.Buffer.from_pyval(t)
# Run the test twice to verify that the original tuple buffer remains valid
# even after destructuring.
for _ in range(2):
pieces = local_buffer.destructure()
self.assertFalse(local_buffer.is_deleted())
self.assertLen(pieces, 2)
array0, array1 = pieces
got = array0.to_py()
want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0])
np.testing.assert_equal(want, got)
got = array1.to_py()
want = NumpyArrayS32([2, 3, 4, 5])
np.testing.assert_equal(want, got)
def testDestructureTupleNested(self):
t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5]))
local_buffer = xla_client.Buffer.from_pyval(t)
pieces = local_buffer.destructure()
self.assertFalse(local_buffer.is_deleted())
self.assertLen(pieces, 2)
tuple0, array1 = pieces
got = array1.to_py()
want = NumpyArrayS32([5])
np.testing.assert_equal(want, got)
got = tuple0.to_py()
self.assertEqual(type(got), tuple)
self.assertLen(got, 2)
np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0])
np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1])
def testMakeTuple(self):
t = (
np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
np.array([2, 3, 4, 5], dtype=np.int32),
)
b0 = xla_client.Buffer.from_pyval(t[0])
b1 = xla_client.Buffer.from_pyval(t[1])
btup = xla_client.Buffer.make_tuple([b0, b1], device=0)
pieces = btup.destructure()
self.assertLen(pieces, 2)
array0, array1 = pieces
np.testing.assert_equal(
np.array([1, 2, 3, 4], dtype=np.float32), array0.to_py())
np.testing.assert_equal(
np.array([2, 3, 4, 5], dtype=np.int32), array1.to_py())
def testShape(self):
pyval = np.array([[1., 2.]], np.float32)
local_buffer = xla_client.Buffer.from_pyval(pyval)
xla_shape = local_buffer.shape()
self.assertEqual(xla_shape.dimensions(), (1, 2))
self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))
def testBlockHostUntilReadyWorks(self):
arg = np.array([[1., 2.]], np.float32)
arg_buffer = xla_client.Buffer.from_pyval(arg)
arg_buffer.block_host_until_ready()
# This test merely checks that nothing goes awry when we call
# block_host_until_ready(); it's difficult to test anything else.
def testCopyToHost(self):
arg0 = np.array([[1., 2.]], np.float32)
arg1 = np.array([[3., 4.]], np.float32)
arg0_buffer = xla_client.Buffer.from_pyval(arg0)
arg1_buffer = xla_client.Buffer.from_pyval(arg1)
# Prefetch two buffers using copy_to_host_async, and then retrieve their
# values using to_py.
arg0_buffer.copy_to_host_async()
arg0_buffer.copy_to_host_async() # Duplicate calls don't do anything.
arg1_buffer.copy_to_host_async()
np.testing.assert_equal(arg0, arg0_buffer.to_py())
np.testing.assert_equal(arg1, arg1_buffer.to_py())
# copy_to_host_async does nothing after to_py is called.
arg0_buffer.copy_to_host_async()
np.testing.assert_equal(arg0, arg0_buffer.to_py())
class SingleOpTest(ComputationTest):
"""Tests for single ops.
The goal here is smoke testing - to exercise the most basic functionality of
single XLA ops. As minimal as possible number of additional ops are added
around the op being tested.
"""
def testConcatenateF32(self):
c = self._NewComputation()
args = (
c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])),
c.Constant(NumpyArrayF32([4.0, 5.0, 6.0])),
)
c.Concatenate(args, dimension=0)
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
def testConcatenateF64(self):
c = self._NewComputation()
args = (
c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])),
c.Constant(NumpyArrayF64([4.0, 5.0, 6.0])),
)
c.Concatenate(args, dimension=0)
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
def testConvertElementType(self):
xla_types = {
np.bool: xla_client.PrimitiveType.PRED,
np.int32: xla_client.PrimitiveType.S32,
np.int64: xla_client.PrimitiveType.S64,
np.float32: xla_client.PrimitiveType.F32,
np.float64: xla_client.PrimitiveType.F64,
}
def _ConvertAndTest(template, src_dtype, dst_dtype):
c = self._NewComputation()
x = c.Constant(np.array(template, dtype=src_dtype))
c.ConvertElementType(x, xla_types[dst_dtype])
result = xla_client.execute_with_python_values(c.Build().Compile())
expected = np.array(template, dtype=dst_dtype)
self.assertEqual(result.shape, expected.shape)
self.assertEqual(result.dtype, expected.dtype)
np.testing.assert_equal(result, expected)
x = [0, 1, 0, 0, 1]
for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
_ConvertAndTest(x, src_dtype, dst_dtype)
def testBitcastConvertType(self):
xla_x32_types = {
np.int32: xla_client.PrimitiveType.S32,
np.float32: xla_client.PrimitiveType.F32,
}
xla_x64_types = {
np.int64: xla_client.PrimitiveType.S64,
np.float64: xla_client.PrimitiveType.F64,
}
def _ConvertAndTest(template, src_dtype, dst_dtype, dst_etype):
c = self._NewComputation()
x = c.Constant(np.array(template, dtype=src_dtype))
c.BitcastConvertType(x, dst_etype)
result = xla_client.execute_with_python_values(c.Build().Compile())
expected = np.array(template, src_dtype).view(dst_dtype)
self.assertEqual(result.shape, expected.shape)
self.assertEqual(result.dtype, expected.dtype)
np.testing.assert_equal(result, expected)
x = [0, 1, 0, 0, 1]
for xla_types in [xla_x32_types, xla_x64_types]:
for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
_ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype])
# TODO(b/123523486) implement AllToAll on CPU
def DISABLED_testAllToAllOneReplica(self):
samples = [
NumpyArrayF32([97.0]),
NumpyArrayF32([64.0, 117.0]),
NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
]
for lhs in samples[:1]:
c = self._NewComputation()
c.AllToAll(c.Constant(lhs), 0, 0)
self._ExecuteAndCompareExact(c, expected=lhs)
def testCrossReplicaSumOneReplica(self):
samples = [
NumpyArrayF32(42.0),
NumpyArrayF32([97.0]),
NumpyArrayF32([64.0, 117.0]),
NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
]
for lhs in samples:
c = self._NewComputation()
c.CrossReplicaSum(c.Constant(lhs))
self._ExecuteAndCompareExact(c, expected=lhs)
def testReplicaId(self):
c = self._NewComputation()
_ = c.ReplicaId()
self._ExecuteAndCompareExact(c, expected=0)
def testCrossReplicaSumOneReplicaWithSingletonGroup(self):
samples = [
NumpyArrayF32(42.0),
NumpyArrayF32([97.0]),
NumpyArrayF32([64.0, 117.0]),
NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
]
for lhs in samples:
c = self._NewComputation()
c.CrossReplicaSum(c.Constant(lhs), [[0]])
self._ExecuteAndCompareExact(c, expected=lhs)
def testDotMatrixVectorF32(self):
c = self._NewComputation()
lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
rhs = NumpyArrayF32([[10.0], [20.0]])
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
def testDotMatrixVectorF64(self):
c = self._NewComputation()
lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
rhs = NumpyArrayF64([[10.0], [20.0]])
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
def testDotMatrixMatrixF32(self):
c = self._NewComputation()
lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]])
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
def testDotMatrixMatrixF64(self):
c = self._NewComputation()
lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]])
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
def testDotGeneral(self):
c = self._NewComputation()
rng = np.random.RandomState(0)
lhs = NumpyArrayF32(rng.randn(10, 3, 4))
rhs = NumpyArrayF32(rng.randn(10, 4, 5))
dimension_numbers = (([2], [1]), ([0], [0]))
c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6)
def testDotGeneralWithDotDimensionNumbersProto(self):
c = self._NewComputation()
rng = np.random.RandomState(0)
lhs = NumpyArrayF32(rng.randn(10, 3, 4))
rhs = NumpyArrayF32(rng.randn(10, 4, 5))
dimension_numbers = xla_client.DotDimensionNumbers()
dimension_numbers.lhs_contracting_dimensions.append(2)
dimension_numbers.rhs_contracting_dimensions.append(1)
dimension_numbers.lhs_batch_dimensions.append(0)
dimension_numbers.rhs_batch_dimensions.append(0)
c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6)
def testDotGeneralWithPrecisionConfig(self):
c = self._NewComputation()
rng = np.random.RandomState(0)
lhs = NumpyArrayF32(rng.randn(10, 3, 4))
rhs = NumpyArrayF32(rng.randn(10, 4, 5))
dimension_numbers = (([2], [1]), ([0], [0]))
config = xla_client.PrecisionConfig()
config.operand_precision.append(config.Precision.HIGH)
config.operand_precision.append(config.Precision.HIGHEST)
c.DotGeneral(
c.Constant(lhs),
c.Constant(rhs),
dimension_numbers,
precision_config=config)
self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6)
def testConvF32Same(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 2, 3, 4)
rhs = a(1, 2, 1, 2) * 10
c.Conv(
c.Constant(lhs), c.Constant(rhs), [1, 1], xla_client.PaddingType.SAME)
result = np.array([[[
[640., 700., 760., 300.],
[880., 940., 1000., 380.],
[1120., 1180., 1240., 460.],
]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvF32Valid(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 2, 3, 4)
rhs = a(1, 2, 1, 2) * 10
c.Conv(
c.Constant(lhs), c.Constant(rhs), [2, 1], xla_client.PaddingType.VALID)
result = np.array([[[
[640., 700., 760.],
[1120., 1180., 1240.],
]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvWithGeneralPaddingF32(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 1, 2, 3)
rhs = a(1, 1, 1, 2) * 10
strides = [1, 1]
pads = [(1, 0), (0, 1)]
lhs_dilation = (2, 1)
rhs_dilation = (1, 1)
c.ConvWithGeneralPadding(
c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation,
rhs_dilation)
result = np.array([[[
[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.],
]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvGeneralDilatedF32(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 1, 2, 3)
rhs = a(1, 1, 1, 2) * 10
strides = [1, 1]
pads = [(1, 0), (0, 1)]
lhs_dilation = (2, 1)
rhs_dilation = (1, 1)
dimension_numbers = ("NCHW", "OIHW", "NCHW")
c.ConvGeneralDilated(
c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation,
rhs_dilation, dimension_numbers)
result = np.array([[[
[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.],
]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvGeneralDilatedF32WithPrecisionConfig(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 1, 2, 3)
rhs = a(1, 1, 1, 2) * 10
strides = [1, 1]
pads = [(1, 0), (0, 1)]
lhs_dilation = (2, 1)
rhs_dilation = (1, 1)
dimension_numbers = ("NCHW", "OIHW", "NCHW")
config = xla_client.PrecisionConfig()
config.operand_precision.append(config.Precision.HIGHEST)
config.operand_precision.append(config.Precision.DEFAULT)
c.ConvGeneralDilated(
c.Constant(lhs),
c.Constant(rhs),
strides,
pads,
lhs_dilation,
rhs_dilation,
dimension_numbers,
precision_config=config)
result = np.array([[[
[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.],
]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvGeneralDilatedPermutedF32(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 1, 2, 3)
rhs = a(1, 1, 1, 2) * 10
strides = [1, 1]
pads = [(1, 0), (0, 1)]
lhs_dilation = (2, 1)
rhs_dilation = (1, 1)
dimension_numbers = ("NHWC", "OIHW", "CWNH")
c.ConvGeneralDilated(
c.Constant(np.transpose(lhs, (0, 2, 3, 1))), c.Constant(rhs), strides,
pads, lhs_dilation, rhs_dilation, dimension_numbers)
result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.],
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
def testConvGeneralDilatedGroupedConvolutionF32(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 2, 2, 3)
rhs = a(2, 1, 1, 2) * 10
strides = [1, 1]
pads = [(1, 0), (0, 1)]
lhs_dilation = (2, 1)
rhs_dilation = (1, 1)
dimension_numbers = ("NCHW", "OIHW", "NCHW")
feature_group_count = 2
c.ConvGeneralDilated(
c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation,
rhs_dilation, dimension_numbers, feature_group_count)
result = np.array([[[
[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.],
], [
[0., 0., 0.],
[330., 380., 160.],
[0., 0., 0.],
[480., 530., 220.],
]]])
self._ExecuteAndCompareClose(c, expected=result)
def testBooleanNot(self):
c = self._NewComputation()
arr = NumpyArrayBool([True, False, True])
c.Not(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=~arr)
def testCountLeadingZeros(self):
c = self._NewComputation()
arr = NumpyArrayS32([0x7FFF, 0x12345678])
c.Clz(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=[17, 3])
def testExp(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Exp(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.exp(arr))
def testExpm1(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Expm1(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.expm1(arr))
def testRound(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Round(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.round(arr))
def testLog(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Log(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.log(arr))
def testLog1p(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Log1p(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.log1p(arr))
def testNeg(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Neg(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=-arr)
def testFloor(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Floor(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.floor(arr))
def testCeil(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Ceil(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.ceil(arr))
def testAbs(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
c.Abs(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.abs(arr))
def testTanh(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Tanh(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.tanh(arr))
def testTrans(self):
def _TransposeAndTest(array):
c = self._NewComputation()
c.Trans(c.Constant(array))
self._ExecuteAndCompareClose(c, expected=array.T)
# Test square and non-square matrices in both default (C) and F orders.
for array_fun in [NumpyArrayF32, NumpyArrayF64]:
_TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]]))
_TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F"))
_TransposeAndTest(array_fun([[1, 2], [4, 5]]))
_TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F"))
def testTranspose(self):
def _TransposeAndTest(array, permutation):
c = self._NewComputation()
c.Transpose(c.Constant(array), permutation)
expected = np.transpose(array, permutation)
self._ExecuteAndCompareClose(c, expected=expected)
_TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
_TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
_TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
_TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])
arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
for permutation in itertools.permutations(range(arr.ndim)):
_TransposeAndTest(arr, permutation)
_TransposeAndTest(np.asfortranarray(arr), permutation)
def testEq(self):
c = self._NewComputation()
c.Eq(
c.Constant(NumpyArrayS32([1, 2, 3, 4])),
c.Constant(NumpyArrayS32([4, 2, 3, 1])))
self._ExecuteAndCompareExact(c, expected=[False, True, True, False])
def testNe(self):
c = self._NewComputation()
c.Ne(
c.Constant(NumpyArrayS32([1, 2, 3, 4])),
c.Constant(NumpyArrayS32([4, 2, 3, 1])))
self._ExecuteAndCompareExact(c, expected=[True, False, False, True])
c.Ne(
c.Constant(NumpyArrayF32([-2.0, 0.0,
float("nan"),
float("nan")])),
c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")])))
self._ExecuteAndAssertWith(
np.testing.assert_allclose, c, (), expected=[True, False, True, True])
def testGt(self):
c = self._NewComputation()
c.Gt(
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False])
def testGe(self):
c = self._NewComputation()
c.Ge(
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False])
def testLt(self):
c = self._NewComputation()
c.Lt(
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True])
def testLe(self):
c = self._NewComputation()
c.Le(
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True])
def testMax(self):
c = self._NewComputation()
c.Max(
c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0])
def testMaxExplicitBroadcastDim0(self):
c = self._NewComputation()
c.Max(
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayF32([3, 4, 5])),
broadcast_dimensions=(0,))
self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]])
def testMaxExplicitBroadcastDim1(self):
c = self._NewComputation()
c.Max(
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayF32([3, 4, 5])),
broadcast_dimensions=(1,))
self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]])
def testMin(self):
c = self._NewComputation()
c.Min(
c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0])
def testPad(self):
c = self._NewComputation()
c.Pad(
c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
c.Constant(NumpyArrayF32(0.0)), [(1, 2, 1), (0, 1, 0)])
self._ExecuteAndCompareClose(
c,
expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
[3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
def testPadWithPaddingConfig(self):
c = self._NewComputation()
padding_config = xla_client.PaddingConfig()
for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]:
dimension = xla_client.PaddingConfigDimension()
dimension.edge_padding_low = lo
dimension.edge_padding_high = hi
dimension.interior_padding = interior
padding_config.dimensions.append(dimension)
c.Pad(
c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
c.Constant(NumpyArrayF32(0.0)), padding_config)
self._ExecuteAndCompareClose(
c,
expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
[3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
def testReshape(self):
c = self._NewComputation()
c.Reshape(
c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
dimensions=[0, 1],
new_sizes=[2, 3])
self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]])
def testCollapse(self):
c = self._NewComputation()
c.Collapse(
c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
dimensions=[1, 2])
self._ExecuteAndCompareExact(c, expected=[[1, 2, 3, 4], [5, 6, 7, 8]])
def testRev(self):
c = self._NewComputation()
c.Rev(
c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
dimensions=[0, 2])
self._ExecuteAndCompareExact(
c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]])
def testReducePrecision(self):
c = self._NewComputation()
c.ReducePrecision(
c.Constant(NumpyArrayF32([float.fromhex("0x1.32fffep-3")])),
exponent_bits=8,
mantissa_bits=7)
self._ExecuteAndCompareClose(c, expected=[float.fromhex("0x1.32p-3")])
def testClampF32(self):
c = self._NewComputation()
c.Clamp(
c.Constant(NumpyArrayF32(-1)),
c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
c.Constant(NumpyArrayF32(2)))
self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2])
def testClampS32(self):
c = self._NewComputation()
c.Clamp(
c.Constant(NumpyArrayS32(-1)),
c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
c.Constant(NumpyArrayS32(2)))
self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2])
def testSelect(self):
c = self._NewComputation()
c.Select(
c.Constant(NumpyArrayBool([True, False, False, True, False])),
c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])),
c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5])))
self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5])
def testSlice(self):
c = self._NewComputation()
c.Slice(
c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0],
[3, 2])
self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
def testSliceInDim(self):
c = self._NewComputation()
c.SliceInDim(
c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
start_index=1,
limit_index=2,
stride=1,
dimno=1)
self._ExecuteAndCompareExact(c, expected=[[2], [5], [8]])
c.SliceInDim(
c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
start_index=0,
limit_index=3,
stride=2,
dimno=0)
self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [7, 8, 9]])
def testDynamicSlice(self):
c = self._NewComputation()
c.DynamicSlice(
c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayS32([1, 0])), [2, 2])
self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
def testDynamicUpdateSlice(self):
c = self._NewComputation()
c.DynamicUpdateSlice(
c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
c.Constant(NumpyArrayS32([[1, 2], [3, 4]])),
c.Constant(NumpyArrayS32([1, 1])))
self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]])
def testTuple(self):
c = self._NewComputation()
c.Tuple(
c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
c.Constant(NumpyArrayBool([True, False, False, True])))
result = xla_client.execute_with_python_values(c.Build().Compile())
self.assertIsInstance(result, tuple)
np.testing.assert_equal(result[0], 42)
np.testing.assert_allclose(result[1], [1.0, 2.0])
np.testing.assert_equal(result[2], [True, False, False, True])
def testGetTupleElement(self):
c = self._NewComputation()
c.GetTupleElement(
c.Tuple(
c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
c.Constant(NumpyArrayBool([True, False, False, True]))), 1)
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0])
def testBroadcast(self):
c = self._NewComputation()
c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
self._ExecuteAndCompareExact(
c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]])
def testBroadcastInDim(self):
c = self._NewComputation()
c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [0])
self._ExecuteAndCompareExact(c, expected=[[1, 1], [2, 2]])
c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [1])
self._ExecuteAndCompareExact(c, expected=[[1, 2], [1, 2]])
def testRngNormal(self):
shape = (2, 3)
c = self._NewComputation()
c.RngNormal(
c.Constant(NumpyArrayF32(0.)),
c.Constant(NumpyArrayF32(1.)),
dims=shape)
result = xla_client.execute_with_python_values(c.Build().Compile())
# since the result is random, we just check shape and uniqueness
self.assertEqual(result.shape, shape)
self.assertLen(np.unique(result), np.prod(shape))
def testRngUniformF32(self):
lo, hi = 2., 4.
shape = (2, 3)
c = self._NewComputation()
c.RngUniform(
c.Constant(NumpyArrayF32(lo)),
c.Constant(NumpyArrayF32(hi)),
dims=shape)
result = xla_client.execute_with_python_values(c.Build().Compile())
# since the result is random, we just check shape, uniqueness, and range
self.assertEqual(result.shape, shape)
self.assertLen(np.unique(result), np.prod(shape))
self.assertTrue(np.all(lo <= result))
self.assertTrue(np.all(result < hi))
def testRngUniformS32(self):
lo, hi = 2, 4
shape = (2, 3)
c = self._NewComputation()
c.RngUniform(
c.Constant(NumpyArrayS32(lo)),
c.Constant(NumpyArrayS32(hi)),
dims=shape)
result = xla_client.execute_with_python_values(c.Build().Compile())
# since the result is random, we just check shape, integrality, and range
self.assertEqual(result.shape, shape)
self.assertEqual(result.dtype, np.int32)
self.assertTrue(np.all(lo <= result))
self.assertTrue(np.all(result < hi))
def testCholesky(self):
l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]],
dtype=np.float32)
c = self._NewComputation()
c.Cholesky(c.Constant(np.dot(l, l.T)))
self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4)
def testSort(self):
keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
c = self._NewComputation()
c.Sort(c.Constant(keys))
self._ExecuteAndCompareClose(
c, expected=np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32))
def testSortKeyVal(self):
keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
c = self._NewComputation()
c.SortKeyVal(c.Constant(keys), c.Constant(values), dimension=0)
result = xla_client.execute_with_python_values(c.Build().Compile())
self.assertIsInstance(result, tuple)
np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]])
np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]])
def testQR(self):
a = np.array(
[[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]],
dtype=np.float32)
c = self._NewComputation()
c.QR(c.Constant(a), full_matrices=True)
q, r = self._Execute(c, ())
np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4)
def testEigh(self):
a = np.array(
[[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]],
dtype=np.float32)
a = (a + a.T) / 2
c = self._NewComputation()
c.Eigh(c.Constant(a), full_matrices=True)
# TODO(b/129396575): Turn this test back on when it passes without fastmath.
# v, w = self._Execute(c, ())
# self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3)
def testSVD(self):
a = np.array(
[[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]],
dtype=np.float32)
c = self._NewComputation()
c.SVD(c.Constant(a))
u, d, v = self._Execute(c, ())
self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3)
def testTriangularSolve(self):
a_vals = np.array(
[[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]],
dtype=np.float32)
b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
dtype=np.float32)
c = self._NewComputation()
c.TriangularSolve(
c.Constant(a_vals),
c.Constant(b_vals),
left_side=False,
lower=True,
transpose_a=True)
self._ExecuteAndCompareClose(
c,
expected=np.array([
[0.5, 0.08333334, 0.04629629, 0.03367003],
[2.5, -0.25, -0.1388889, -0.1010101],
[4.5, -0.58333331, -0.32407406, -0.23569024],
],
dtype=np.float32),
rtol=1e-4)
def testIsConstant(self):
c = self._NewComputation()
a = c.ConstantS32Scalar(3)
b = c.ConstantS32Scalar(1)
x = c.ParameterFromNumpy(NumpyArrayS32(0))
const_expr = c.Sub(b, a)
non_const_expr = c.Mul(const_expr, x)
self.assertTrue(c.IsConstant(const_expr))
self.assertFalse(c.IsConstant(non_const_expr))
# self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564)
def testGather(self):
a = np.arange(9).astype(np.int32).reshape((3, 3))
indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32)
dnums = xla_client.GatherDimensionNumbers()
dnums.offset_dims.append(1)
dnums.offset_dims.append(2)
dnums.start_index_map.append(0)
dnums.start_index_map.append(1)
dnums.index_vector_dim = 2
c = self._NewComputation()
c.Gather(c.Constant(a), c.Constant(indices), dnums, slice_sizes=[1, 1])
g = self._Execute(c, ())
expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32)
np.testing.assert_allclose(g, expected, rtol=1e-4)
def testFft(self):
shape = [2, 3, 4, 5]
rng = np.random.RandomState(0)
a = rng.randn(*shape) + 1.0j * rng.randn(*shape)
a = a.astype(np.complex64)
# FFT
c = self._NewComputation()
c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:])
self._ExecuteAndCompareClose(c, expected=np.fft.fftn(a, axes=(1, 2, 3)),
rtol=1e-4)
# IFFT
c = self._NewComputation()
c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:])
self._ExecuteAndCompareClose(c, expected=np.fft.ifftn(a, axes=(1, 2, 3)),
rtol=1e-4)
# RFFT
b = rng.randn(*shape).astype(np.float32)
c = self._NewComputation()
c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:])
self._ExecuteAndCompareClose(c, expected=np.fft.rfftn(b, axes=(1, 2, 3)),
rtol=1e-4)
# IRFFT
c = self._NewComputation()
c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8])
self._ExecuteAndCompareClose(c, expected=np.fft.irfftn(a, axes=(1, 2, 3)),
rtol=1e-4)
class EmbeddedComputationsTest(ComputationTest):
"""Tests for XLA graphs with embedded computations (such as maps)."""
def _CreateConstantS32Computation(self):
"""Computation (f32) -> s32 that returns a constant 1 for any input."""
c = self._NewComputation("constant_s32_one")
# TODO(eliben): consider adding a nicer way to create new parameters without
# having to create dummy Numpy arrays or populating Shape messages. Perhaps
# we need our own (Python-client-own) way to represent Shapes conveniently.
c.ParameterFromNumpy(NumpyArrayF32(0))
c.ConstantS32Scalar(1)
return c.Build()
def _CreateConstantS64Computation(self):
"""Computation (f64) -> s64 that returns a constant 1 for any input."""
c = self._NewComputation("constant_s64_one")
# TODO(eliben): consider adding a nicer way to create new parameters without
# having to create dummy Numpy arrays or populating Shape messages. Perhaps
# we need our own (Python-client-own) way to represent Shapes conveniently.
c.ParameterFromNumpy(NumpyArrayF64(0))
c.ConstantS64Scalar(1)
return c.Build()
def _CreateConstantF32Computation(self):
"""Computation (f32) -> f32 that returns a constant 1.0 for any input."""
c = self._NewComputation("constant_f32_one")
c.ParameterFromNumpy(NumpyArrayF32(0))
c.ConstantF32Scalar(1.0)
return c.Build()
def _CreateConstantF64Computation(self):
"""Computation (f64) -> f64 that returns a constant 1.0 for any input."""
c = self._NewComputation("constant_f64_one")
c.ParameterFromNumpy(NumpyArrayF64(0))
c.ConstantF64Scalar(1.0)
return c.Build()
def _CreateMulF32By2Computation(self):
"""Computation (f32) -> f32 that multiplies its parameter by 2."""
c = self._NewComputation("mul_f32_by2")
c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0))
return c.Build()
def _CreateMulF32ByParamComputation(self):
"""Computation (f32) -> f32 that multiplies one parameter by the other."""
c = self._NewComputation("mul_f32_by_param")
c.Mul(
c.ParameterFromNumpy(NumpyArrayF32(0)),
c.ParameterFromNumpy(NumpyArrayF32(0)))
return c.Build()
def _CreateMulF64By2Computation(self):
"""Computation (f64) -> f64 that multiplies its parameter by 2."""
c = self._NewComputation("mul_f64_by2")
c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0))
return c.Build()
def _CreateBinaryAddS32Computation(self):
"""Computation (s32, s32) -> s32 that adds its two parameters."""
c = self._NewComputation("add_param0_by_param1")
c.Add(
c.ParameterFromNumpy(NumpyArrayS32(0)),
c.ParameterFromNumpy(NumpyArrayS32(0)))
return c.Build()
def _CreateBinaryAddF32Computation(self):
"""Computation (f32, f32) -> f32 that adds its two parameters."""
c = self._NewComputation("add_param0_by_param1")
c.Add(
c.ParameterFromNumpy(NumpyArrayF32(0)),
c.ParameterFromNumpy(NumpyArrayF32(0)))
return c.Build()
def _CreateBinaryAddF64Computation(self):
"""Computation (f64, f64) -> f64 that adds its two parameters."""
c = self._NewComputation("add_param0_by_param1")
c.Add(
c.ParameterFromNumpy(NumpyArrayF64(0)),
c.ParameterFromNumpy(NumpyArrayF64(0)))
return c.Build()
def _CreateBinaryDivF32Computation(self):
"""Computation (f32, f32) -> f32 that divides its two parameters."""
c = self._NewComputation("div_param0_by_param1")
c.Div(
c.ParameterFromNumpy(NumpyArrayF32(0)),
c.ParameterFromNumpy(NumpyArrayF32(0)))
return c.Build()
def _CreateBinaryDivF64Computation(self):
"""Computation (f64, f64) -> f64 that divides its two parameters."""
c = self._NewComputation("div_param0_by_param1")
c.Div(
c.ParameterFromNumpy(NumpyArrayF64(0)),
c.ParameterFromNumpy(NumpyArrayF64(0)))
return c.Build()
def _CreateTestF32Lt10Computation(self):
"""Computation (f32) -> bool that tests if its parameter is less than 10."""
c = self._NewComputation("test_f32_lt_10")
c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.))
return c.Build()
def _CreateTestF64Lt10Computation(self):
"""Computation (f64) -> bool that tests if its parameter is less than 10."""
c = self._NewComputation("test_f64_lt_10")
c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.))
return c.Build()
def _CreateBinaryGeF32Computation(self):
"""Computation (f32, f32) -> bool that tests first_param >= second_param."""
c = self._NewComputation("param0_lt_param1")
c.Ge(
c.ParameterFromNumpy(NumpyArrayF32(0)),
c.ParameterFromNumpy(NumpyArrayF32(0)))
return c.Build()
def _CreateBinaryGeF64Computation(self):
"""Computation (f64, f64) -> bool that tests first_param >= second_param."""
c = self._NewComputation("param0_lt_param1")
c.Ge(
c.ParameterFromNumpy(NumpyArrayF64(0)),
c.ParameterFromNumpy(NumpyArrayF64(0)))
return c.Build()
def _MakeSample3DArrayF32(self):
return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
def _MakeSample3DArrayF64(self):
return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
def testCallF32(self):
c = self._NewComputation()
c.Call(
self._CreateMulF32By2Computation(),
operands=(c.ConstantF32Scalar(5.0),))
self._ExecuteAndCompareClose(c, expected=10.0)
def testCallF64(self):
c = self._NewComputation()
c.Call(
self._CreateMulF64By2Computation(),
operands=(c.ConstantF64Scalar(5.0),))
self._ExecuteAndCompareClose(c, expected=10.0)
def testMapEachElementToS32Constant(self):
c = self._NewComputation()
c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
self._CreateConstantS32Computation(), [0])
self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
def testMapEachElementToS64Constant(self):
c = self._NewComputation()
c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
self._CreateConstantS64Computation(), [0])
self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
def testMapMulBy2F32(self):
c = self._NewComputation()
c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
self._CreateMulF32By2Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
def testMapMulBy2F64(self):
c = self._NewComputation()
c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
self._CreateMulF64By2Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
def testSimpleMapChainF32(self):
# Chains a map of constant-f32 with a map of mul-by-2
c = self._NewComputation()
const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
self._CreateConstantF32Computation(), [0])
c.Map([const_f32], self._CreateMulF32By2Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
def testSimpleMapChainF64(self):
# Chains a map of constant-f64 with a map of mul-by-2
c = self._NewComputation()
const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
self._CreateConstantF64Computation(), [0])
c.Map([const_f64], self._CreateMulF64By2Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
def testDivVectorsWithMapF32(self):
c = self._NewComputation()
c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))),
self._CreateBinaryDivF32Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
def testDivVectorsWithMapF64(self):
c = self._NewComputation()
c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))),
self._CreateBinaryDivF64Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
def testSelectAndScatterF32(self):
c = self._NewComputation()
c.SelectAndScatter(
c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])),
select=self._CreateBinaryGeF32Computation(),
window_dimensions=(2, 1),
window_strides=(1, 2),
padding=xla_client.PaddingType.VALID,
source=c.Constant(NumpyArrayF32([[0.1, 0.2]])),
init_value=c.Constant(NumpyArrayF32(1)),
scatter=self._CreateBinaryAddF32Computation())
self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
def testSelectAndScatterF64(self):
c = self._NewComputation()
c.SelectAndScatter(
c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])),
select=self._CreateBinaryGeF64Computation(),
window_dimensions=(2, 1),
window_strides=(1, 2),
padding=xla_client.PaddingType.VALID,
source=c.Constant(NumpyArrayF64([[0.1, 0.2]])),
init_value=c.Constant(NumpyArrayF64(1)),
scatter=self._CreateBinaryAddF64Computation())
self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
def testReduce1DtoScalarF32(self):
c = self._NewComputation()
c.Reduce(
operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
dimensions=[0])
self._ExecuteAndCompareClose(c, expected=10)
def testReduce1DtoScalarF64(self):
c = self._NewComputation()
c.Reduce(
operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
dimensions=[0])
self._ExecuteAndCompareClose(c, expected=10)
def testReduce2DTo1DDim0F32(self):
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.Reduce(
operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
dimensions=[0])
self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
def testReduce2DTo1DDim0F64(self):
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.Reduce(
operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
dimensions=[0])
self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
def testReduce2DTo1DDim1F32(self):
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.Reduce(
operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
dimensions=[1])
self._ExecuteAndCompareClose(c, expected=[6, 15])
def testReduce2DTo1DDim1F64(self):
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.Reduce(
operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
dimensions=[1])
self._ExecuteAndCompareClose(c, expected=[6, 15])
def testReduce3DAllPossibleWaysF32(self):
input_array = self._MakeSample3DArrayF32()
def _ReduceAndTest(*dims):
c = self._NewComputation()
c.Reduce(
operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
dimensions=dims)
self._ExecuteAndCompareClose(
c, expected=np.sum(input_array, axis=tuple(dims)))
_ReduceAndTest(0)
_ReduceAndTest(0, 1)
_ReduceAndTest(0, 2)
_ReduceAndTest(1, 2)
_ReduceAndTest(0, 1, 2)
def testReduce3DAllPossibleWaysF64(self):
input_array = self._MakeSample3DArrayF64()
def _ReduceAndTest(*dims):
c = self._NewComputation()
c.Reduce(
operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
dimensions=dims)
self._ExecuteAndCompareClose(
c, expected=np.sum(input_array, axis=tuple(dims)))
_ReduceAndTest(0)
_ReduceAndTest(0)
_ReduceAndTest(0, 1)
_ReduceAndTest(0, 2)
_ReduceAndTest(1, 2)
_ReduceAndTest(0, 1, 2)
def testReduceWindowValidUnitStridesF32(self):
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.ReduceWindow(
operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
window_dimensions=(2, 1),
window_strides=(1, 1),
padding=xla_client.PaddingType.VALID)
self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
def testReduceWindowSameUnitStridesF32(self):
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.ReduceWindow(
operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
window_dimensions=(2, 1),
window_strides=(1, 1),
padding=xla_client.PaddingType.SAME)
self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
def testReduceWindowValidGeneralStridesF32(self):
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.ReduceWindow(
operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
window_dimensions=(2, 1),
window_strides=(1, 2),
padding=xla_client.PaddingType.VALID)
self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
def testReduceWindowValidUnitStridesF64(self):
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.ReduceWindow(
operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
window_dimensions=(2, 1),
window_strides=(1, 1),
padding=xla_client.PaddingType.VALID)
self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
def testReduceWindowSameUnitStridesF64(self):
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.ReduceWindow(
operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
window_dimensions=(2, 1),
window_strides=(1, 1),
padding=xla_client.PaddingType.SAME)
self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
def testReduceWindowValidGeneralStridesF64(self):
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.ReduceWindow(
operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
window_dimensions=(2, 1),
window_strides=(1, 2),
padding=xla_client.PaddingType.VALID)
self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
def testWhileF32(self):
cond = self._CreateTestF32Lt10Computation()
body = self._CreateMulF32By2Computation()
c = self._NewComputation()
init = c.ConstantF32Scalar(1.)
c.While(cond, body, init)
self._ExecuteAndCompareClose(c, expected=16.)
def testWhileF64(self):
cond = self._CreateTestF64Lt10Computation()
body = self._CreateMulF64By2Computation()
c = self._NewComputation()
init = c.ConstantF64Scalar(1.)
c.While(cond, body, init)
self._ExecuteAndCompareClose(c, expected=16.)
def testConditionalTrue(self):
c = self._NewComputation()
pred = c.ConstantPredScalar(True)
true_operand = c.ConstantF32Scalar(3.)
true_computation = self._CreateMulF32By2Computation()
false_operand = c.ConstantF32Scalar(2.)
false_computation = self._CreateConstantF32Computation()
c.Conditional(pred, true_operand, true_computation, false_operand,
false_computation)
self._ExecuteAndCompareClose(c, expected=6.)
def testConditionalFalse(self):
c = self._NewComputation()
pred = c.ConstantPredScalar(False)
true_operand = c.ConstantF32Scalar(3.)
true_computation = self._CreateMulF32By2Computation()
false_operand = c.ConstantF32Scalar(2.)
false_computation = self._CreateConstantF32Computation()
c.Conditional(pred, true_operand, true_computation, false_operand,
false_computation)
self._ExecuteAndCompareClose(c, expected=1.)
def testInfeedS32Values(self):
to_infeed = NumpyArrayS32([1, 2, 3, 4])
c = self._NewComputation()
c.Infeed(xla_client.shape_from_pyval(to_infeed[0]))
compiled_c = c.Build().Compile()
for item in to_infeed:
xla_client.transfer_to_infeed(item)
for item in to_infeed:
result = xla_client.execute_with_python_values(compiled_c)
self.assertEqual(result, item)
def testInfeedThenOutfeedS32(self):
to_round_trip = NumpyArrayS32([1, 2, 3, 4])
c = self._NewComputation()
x = c.Infeed(xla_client.shape_from_pyval(to_round_trip[0]))
c.Outfeed(x)
compiled_c = c.Build().Compile()
for want in to_round_trip:
execution = threading.Thread(target=lambda: compiled_c.Execute([]))
execution.start()
xla_client.transfer_to_infeed(want)
got = xla_client.transfer_from_outfeed(
xla_client.shape_from_pyval(to_round_trip[0]))
execution.join()
self.assertEqual(want, got)
def testScatter(self):
a = np.arange(9).astype(np.int32).reshape((3, 3))
scatter_indices = np.array([0, 2], dtype=np.int32)
updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32)
dnums = xla_client.ScatterDimensionNumbers()
dnums.update_window_dims.append(1)
dnums.inserted_window_dims.append(0)
dnums.scatter_dims_to_operand_dims.append(0)
dnums.index_vector_dim = 1
c = self._NewComputation()
c.Scatter(
c.Constant(a), c.Constant(scatter_indices), c.Constant(updates),
self._CreateBinaryAddS32Computation(), dnums)
expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32)
self._ExecuteAndCompareClose(c, expected=expected)
class ErrorTest(ComputationTest):
def setUp(self):
self.f32_scalar_2 = NumpyArrayF32(2.0)
self.s32_scalar_2 = NumpyArrayS32(2)
def testCompileWithWrongElementTypeInLayout(self):
c = self._NewComputation()
c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
c.ParameterFromNumpy(self.s32_scalar_2)
c.ClearOpMetadata()
options = xla_client.CompileOptions()
options.argument_layouts = [
xla_client.Shape.array_shape(np.dtype(np.float32), [])
]
def TestFun():
return c.Build().Compile(compile_options=options)
self.assertRaisesRegexp(
RuntimeError, r".*Invalid argument shape.*"
r"expected s32\[\], got f32\[\].*", TestFun)
def testInvokeWithWrongElementType(self):
c = self._NewComputation()
c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
c.ParameterFromNumpy(self.s32_scalar_2)
c.ClearOpMetadata()
def TestFun():
return xla_client.execute_with_python_values(c.Build().Compile(),
[self.f32_scalar_2])
self.assertRaisesRegexp(
RuntimeError, r"Invalid argument: Argument does not match.*"
r"want s32\[\], got f32\[\].*", TestFun)
class ComputationRootTest(ComputationTest):
"""Tests related to setting the root of the computation."""
def testComputationRootDifferentFromLastOp(self):
c = self._NewComputation()
x = c.ParameterFromNumpy(NumpyArrayF32(2.0))
result = c.Add(x, c.ConstantF32Scalar(3.14))
extra = c.Add(result, c.ConstantF32Scalar(1.618)) # pylint: disable=unused-variable
arg = NumpyArrayF32(1.0)
compiled_c = c.Build(result).Compile()
ans = xla_client.execute_with_python_values(compiled_c, [arg])
np.testing.assert_allclose(ans, 4.14)
if __name__ == "__main__":
absltest.main()