| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/cholesky_expander.h" |
| |
| #include <memory> |
| #include <vector> |
| |
| #include "tensorflow/compiler/xla/client/lib/arithmetic.h" |
| #include "tensorflow/compiler/xla/client/lib/constants.h" |
| #include "tensorflow/compiler/xla/client/lib/loops.h" |
| #include "tensorflow/compiler/xla/client/lib/math.h" |
| #include "tensorflow/compiler/xla/client/lib/matrix.h" |
| #include "tensorflow/compiler/xla/client/lib/slicing.h" |
| #include "tensorflow/compiler/xla/client/xla_builder.h" |
| #include "tensorflow/compiler/xla/literal.h" |
| #include "tensorflow/compiler/xla/primitive_util.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| |
| namespace xla { |
| |
// The Cholesky–Banachiewicz algorithm. See
// https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms
// for a description.
//
// def cholesky_unblocked(a):
//   assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1]
//   n = a.shape[-2]
//   l = np.zeros_like(a)
//   for j in xrange(n):
//     mask = np.zeros_like(a)
//     mask[i, k] = 1 when i >= k and k == j
//     l_square = np.dot(l, l_t)
//     temp = a - l_square
//     l[..., j, j] = np.sqrt(temp[..., j, j])
//     l = temp / l[..., j, j] * mask + l
//   return l
// Returns a (result, error) pair.
| StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked( |
| XlaOp a, PrecisionConfig::Precision precision) { |
| XlaBuilder* builder = a.builder(); |
| TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); |
| const int ndims = a_shape.rank(); |
| const int64 n = ShapeUtil::GetDimension(a_shape, -1); |
| std::vector<int64> error_dims(a_shape.dimensions().begin(), |
| a_shape.dimensions().end()); |
| error_dims.back() = error_dims.at(ndims - 2) = 1; |
| |
| auto major_dims = AsInt64Slice(a_shape.dimensions()) |
| .subspan( |
| /*pos=*/0, |
| /*len=*/ndims - 2); |
| |
| auto matrix_dims = AsInt64Slice(a_shape.dimensions()) |
| .subspan( |
| /*pos=*/0, |
| /*len=*/ndims); |
| |
| XlaOp l = ZerosLike(a); |
| |
| // Construct the for loop body to iterate over rows. |
| auto body_fn = [&](XlaOp i, absl::Span<const XlaOp> loop_vars, |
| XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> { |
| std::vector<int64> row_shape_dims(major_dims.begin(), major_dims.end()); |
| std::vector<int64> col_shape_dims(major_dims.begin(), major_dims.end()); |
| auto body_a = loop_vars[0]; |
| auto body_l = loop_vars[1]; |
| auto seen_error = loop_vars[2]; |
| auto iota_row = |
| Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 1); |
| auto iota_col = |
| Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 2); |
| |
| auto mask_pred = Ge(iota_col, iota_row); |
| mask_pred = And(mask_pred, Eq(iota_row, i)); |
| auto mask_zeros = |
| Zeros(body_builder, |
| ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims)); |
| // L * L.T, This matrix has of a lot of multiplying with zero |
| // (namely, L[:, j:] = 0) and redundant computation, but it is faster |
| // than slice. |
| auto l_square = |
| BatchDot(body_l, false, MaybeConjugate(body_l, true), true, precision); |
| |
| // A - L*L.T |
| l_square = body_a - l_square; |
| auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1}); |
| if (ShapeUtil::ElementIsComplex(a_shape)) { |
| auto sqrt = Sqrt(Real(l_ii)); |
| l_ii = Complex(sqrt, ZerosLike(sqrt)); |
| seen_error = Or(seen_error, IsNan(sqrt)); |
| } else { |
| l_ii = Sqrt(l_ii); |
| seen_error = Or(seen_error, IsNan(l_ii)); |
| } |
| // L = (A - L*L.T) / l_ii * mask + L |
| body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l; |
| |
| return std::vector<XlaOp>{body_a, body_l, seen_error}; |
| }; |
| |
| TF_ASSIGN_OR_RETURN( |
| auto cholesky_while, |
| ForEachIndex( |
| n, S32, body_fn, |
| {a, l, Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims))}, |
| "unblocked", builder)); |
| |
| return std::make_pair(cholesky_while[1], cholesky_while[2]); |
| } |
| |
// Blocked left-looking Cholesky factorization of the batched square
// matrices `a`, processing `block_size` columns per outer step: each
// diagonal tile is factorized with CholeskyUnblocked (or a scalar sqrt for
// 1x1 tiles) and the panel below it is completed with a TriangularSolve.
// Batch elements in which a NaN was produced (input not positive definite)
// have their entire result replaced with NaN.
//
// Args:
//   a: batched square matrices, shape [..., n, n]; rank must be >= 2.
//   block_size: number of columns per panel; must be >= 1.
//   precision: precision config applied to the inner matrix products.
// Returns the lower-triangular factors (or a builder error for invalid
// arguments).
XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
                                      PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    const int ndims = a_shape.rank();
    if (ndims < 2) {
      return InvalidArgument(
          "Argument to Cholesky must have rank >= 2; shape was %s",
          a_shape.ToString());
    }

    const int64 n = ShapeUtil::GetDimension(a_shape, -1);
    if (n != ShapeUtil::GetDimension(a_shape, -2)) {
      return InvalidArgument(
          "Argument to Cholesky must be batched square matrices; got shape %s",
          ShapeUtil::HumanString(a_shape));
    }

    if (block_size < 1) {
      return InvalidArgument(
          "block_size argument to Cholesky must be >= 1; got %d", block_size);
    }

    // Shape of the per-batch error flag: batch dims with a trailing [1, 1].
    std::vector<int64> error_dims(a_shape.dimensions().begin(),
                                  a_shape.dimensions().end());
    error_dims.back() = error_dims.at(ndims - 2) = 1;
    std::vector<int64> error_dim_indices(ndims);
    absl::c_iota(error_dim_indices, 0);

    // Blocked left-looking Cholesky factorization.
    // Algorithm 1 from
    // Haidar, Azzam, et al. "High-performance Cholesky factorization for
    // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
    XlaOp l = ZerosLike(a);
    XlaOp seen_error = Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims));
    for (int64 i = 0; i < n; i += block_size) {
      // k is the width of this panel; the final panel may be narrower.
      int64 k = std::min(block_size, n - i);
      auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
      if (i > 0) {
        // TODO(phawkins): consider implementing SYRK for the diagonal part of
        // the panel.
        // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
        auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
        auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
        auto delta =
            BatchDot(lhs, false, MaybeConjugate(rhs, true), true, precision);
        panel = panel - delta;
      }

      // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
      auto x = SliceInMinorDims(panel, {0, 0}, {k, k});
      XlaOp factorized;
      // TODO(b/167896062): A failure in one element of a batch shouldn't fail
      // other elements.
      XlaOp factorized_error;
      if (k == 1) {
        // 1x1 tile: the factor is just the square root of the entry.
        if (ShapeUtil::ElementIsComplex(a_shape)) {
          auto sqrt = Sqrt(Real(x));
          factorized = Complex(sqrt, ZerosLike(sqrt));
          factorized_error = IsNan(sqrt);
        } else {
          factorized = Sqrt(x);
          factorized_error = IsNan(factorized);
        }
      } else {
        TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
        std::tie(factorized, factorized_error) = tile_output;
      }
      seen_error = Or(seen_error, factorized_error);
      l = UpdateSliceInMinorDims(l, factorized, {i, i});

      if (i + k < n) {
        // l[i+k:, i:i+k] =
        //   trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
        auto update = TriangularSolve(
            factorized, SliceInMinorDims(panel, {k, 0}, {n - i, k}),
            /*left_side=*/false,
            /*lower=*/true,
            /*unit_diagonal=*/false,
            /*transpose_a=*/TriangularSolveOptions::ADJOINT);
        l = UpdateSliceInMinorDims(l, update, {i + k, i});
      }
    }
    // Broadcast each batch element's error flag over its whole matrix and
    // replace failed factorizations with NaN.
    return Select(
        BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices),
        FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
  });
}
| |
| bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) { |
| return instruction->opcode() == HloOpcode::kCholesky; |
| } |
| |
| StatusOr<HloInstruction*> CholeskyExpander::ExpandInstruction( |
| HloInstruction* instruction) { |
| const CholeskyOptions& options = instruction->cholesky_options(); |
| const string name = absl::StrFormat( |
| "xla.cholesky_%s_%s", instruction->operand(0)->shape().ToString(), |
| options.lower() ? "lower" : "upper"); |
| |
| HloModule* module = instruction->parent()->parent(); |
| |
| HloComputation*& computation = |
| computation_cache_.emplace(name, nullptr).first->second; |
| if (!computation) { |
| // Builds a new expansion. |
| // |
| // TODO(b/62327888): We do something unusual here: we build the computation |
| // using the XlaBuilder API, which is nominally an XLA client API. We do |
| // this because the external APIs for building complicated computations |
| // (XlaBuilder) are much more ergonomic than the internal ones. As it turns |
| // out, XlaBuilder isn't really a client API—what it does is build a |
| // HloModuleProto protocol buffer, that we can then deserialize and clone |
| // into our HloModule. Ideally we would avoid the protocol buffer step; |
| // that is left as an exercise for future work. |
| XlaBuilder builder(name); |
| XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a"); |
| XlaOp l = BuildCholesky(MaybeTransposeInMinorDims(a, !options.lower()), |
| /*block_size=*/128, |
| /*precision=*/PrecisionConfig::HIGHEST); |
| MaybeTransposeInMinorDims(l, !options.lower()); |
| |
| TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); |
| |
| TF_ASSIGN_OR_RETURN(ProgramShape program_shape, |
| xla_computation.GetProgramShape()); |
| HloModuleConfig config(program_shape); |
| TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( |
| xla_computation.proto(), config)); |
| HloCloneContext context(module); |
| computation = |
| module->DeepCloneComputation(new_module->entry_computation(), &context); |
| } |
| |
| return instruction->parent()->AddInstruction(HloInstruction::CreateCall( |
| instruction->shape(), instruction->operands(), computation)); |
| } |
| |
| } // namespace xla |