blob: 3f5703510a31b9d557e8b99a1a6c0e082a7b1a49 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
// Computes the LU decomposition with partial pivoting of a batch of matrices.
//
// Given a (batched) matrix a with shape [..., m, n], computes the matrix
// decomposition A = P @ L @ U where P is a permutation matrix, L is a
// lower-triangular matrix with unit diagonal entries, and U is an
// upper-triangular matrix.
//
// L and U are returned as a single matrix [..., m, n] containing both L and U
// packed in the same array. The unit diagonal of L is not represented
// explicitly.
//
// The permutation matrix P is returned in two forms: as `pivots`, an
// s32[..., min(m, n)] array that describes a sequence of row swaps in the
// style of LAPACK's xGETRF API, and as `permutation`, an s32[..., m] array
// that gives the permutation to apply to the rows. We return both
// representations because they are each useful for different purposes; `pivots`
// is useful for computing the sign of a determinant, whereas `permutation` can
// be used via a Gather operation to permute the rows of a matrix.
//
// This method is only implemented on TPU at the moment.
// TODO(b/168208200): the implementation only supports F32 arrays. Handle the
// complex case.
struct LuDecompositionResult {
// Aggregates the three outputs returned by LuDecomposition(). In the shapes
// below, m and n are the trailing two dimensions of the input matrix.

// The LU decomposition, with both L and U packed into an array with shape
// [..., m, n]. The unit diagonal of L is not represented explicitly.
XlaOp lu;
// An array of shape s32[..., min(m, n)] containing the pivot rows, expressed
// as a sequence of row swaps in the style of LAPACK's xGETRF API. Useful for
// computing the sign of a determinant.
XlaOp pivots;
// An array of shape s32[..., m] containing another representation of the
// pivots, as a permutation to apply to the rows. Can be used via a Gather
// operation to permute the rows of a matrix.
XlaOp permutation;
};
// Computes the LU decomposition with partial pivoting of `a`, a (batched)
// matrix of shape [..., m, n]. Returns the packed L/U factors together with
// both representations of the row permutation; see LuDecompositionResult for
// the shape and meaning of each output.
LuDecompositionResult LuDecomposition(XlaOp a);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_