/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/dnn.h"
namespace xla {
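
// Converts an MHLO convolution dimension numbers attribute into the
// equivalent XLA ConvolutionDimensionNumbers proto.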
ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
    mlir::mhlo::ConvDimensionNumbersAttr input) {
  ConvolutionDimensionNumbers output;

  output.set_input_batch_dimension(input.getInputBatchDimension());
  output.set_input_feature_dimension(input.getInputFeatureDimension());
  for (auto v : input.getInputSpatialDimensions()) {
    output.add_input_spatial_dimensions(v);
  }

  output.set_kernel_input_feature_dimension(
      input.getKernelInputFeatureDimension());
  output.set_kernel_output_feature_dimension(
      input.getKernelOutputFeatureDimension());
  for (auto v : input.getKernelSpatialDimensions()) {
    output.add_kernel_spatial_dimensions(v);
  }

  output.set_output_batch_dimension(input.getOutputBatchDimension());
  output.set_output_feature_dimension(input.getOutputFeatureDimension());
  for (auto v : input.getOutputSpatialDimensions()) {
    output.add_output_spatial_dimensions(v);
  }

  return output;
}
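
// Maps an LMHLO-GPU fused convolution activation kind to the corresponding
// StreamExecutor DNN activation mode.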
StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
    mlir::lmhlo_gpu::Activation activation) {
  switch (activation) {
    case mlir::lmhlo_gpu::Activation::None:
      return stream_executor::dnn::kNone;
    case mlir::lmhlo_gpu::Activation::Sigmoid:
      return stream_executor::dnn::kSigmoid;
    case mlir::lmhlo_gpu::Activation::Tanh:
      return stream_executor::dnn::kTanh;
    case mlir::lmhlo_gpu::Activation::Relu:
      return stream_executor::dnn::kRelu;
    case mlir::lmhlo_gpu::Activation::Relu6:
      return stream_executor::dnn::kRelu6;
    case mlir::lmhlo_gpu::Activation::ReluX:
      return stream_executor::dnn::kReluX;
    case mlir::lmhlo_gpu::Activation::BandPass:
      return stream_executor::dnn::kBandPass;
    default:
      return InternalError("Unexpected activation");
  }
}

// Converts replica groups from the MLIR encoding to HLO ReplicaGroups.
// See HloFunctionImporter::ConvertReplicaGroups for the MLIR encoding.
StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
    mlir::DenseIntElementsAttr input) {
  mlir::RankedTensorType type =
      input.getType().dyn_cast<mlir::RankedTensorType>();
  if (!type || type.getRank() != 2 ||
      !type.getElementType().isInteger(/*width=*/64)) {
    return InternalError("Expected replica group to be a rank 2 tensor of i64");
  }
  // Dimension 0 is the number of groups; dimension 1 is the group size.
  auto replica_group_values_it = input.getValues<uint64_t>().begin();
  std::vector<ReplicaGroup> replica_groups(type.getDimSize(0));
  for (ReplicaGroup& group : replica_groups) {
    for (int64_t element_idx = 0; element_idx < type.getDimSize(1);
         ++element_idx, ++replica_group_values_it) {
      // In the replica group attribute, -1 indicates padding added by
      // HloFunctionImporter::ConvertReplicaGroups. The padding is always at
      // the end and can be dropped when converting back to XLA HLO
      // ReplicaGroups.
      if (*replica_group_values_it != -1) {
        group.add_replica_ids(*replica_group_values_it);
      }
    }
  }
  return replica_groups;
}

// Converts an (N, 2) dense attribute to a list of pairs. This is how padding
// and source-target pairs are defined in HLO.
StatusOr<std::vector<std::pair<int64_t, int64_t>>> ConvertNx2Attribute(
    llvm::Optional<mlir::DenseIntElementsAttr> optional_attr) {
  if (!optional_attr.has_value())
    return std::vector<std::pair<int64_t, int64_t>>{};
  mlir::DenseIntElementsAttr attr = *optional_attr;
  auto type = attr.getType().dyn_cast<mlir::RankedTensorType>();
  if (!type || type.getRank() != 2 || type.getShape()[1] != 2)
    return InternalError("expected Nx2 attribute to be a tensor of shape Nx2");
  auto it = attr.getValues<int64_t>().begin();
  std::vector<std::pair<int64_t, int64_t>> out(attr.getNumElements() / 2);
  for (auto& item : out) {
    // Each output pair consumes two consecutive values of the flattened
    // attribute.
    int64_t first = *it;
    ++it;
    int64_t second = *it;
    ++it;
    item = {first, second};
  }
  return out;
}
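
// Converts an FFT type string (e.g. "FFT", "IRFFT") to the XLA FftType enum.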
StatusOr<FftType> ConvertFftType(llvm::StringRef type_string) {
  llvm::Optional<mlir::mhlo::FftType> type =
      mlir::mhlo::symbolizeEnum<mlir::mhlo::FftType>(type_string);
  if (!type) return InvalidArgument("Unknown FFT type %s", type_string.str());

  switch (*type) {
    case mlir::mhlo::FftType::FFT:
      return xla::FftType::FFT;
    case mlir::mhlo::FftType::IFFT:
      return xla::FftType::IFFT;
    case mlir::mhlo::FftType::RFFT:
      return xla::FftType::RFFT;
    case mlir::mhlo::FftType::IRFFT:
      return xla::FftType::IRFFT;
    default:
      return InvalidArgument("Unknown FFT type enum #%d", *type);
  }
}
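
// Converts a triangular-solve transpose string to the XLA
// TriangularSolveOptions::Transpose enum.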
StatusOr<TriangularSolveOptions::Transpose> ConvertTranspose(
    llvm::StringRef transpose_string) {
  llvm::Optional<mlir::mhlo::Transpose> transpose =
      mlir::mhlo::symbolizeTranspose(transpose_string);
  if (!transpose)
    return InvalidArgument("Unknown transpose type %s", transpose_string.str());

  switch (*transpose) {
    case mlir::mhlo::Transpose::NO_TRANSPOSE:
      return TriangularSolveOptions::NO_TRANSPOSE;
    case mlir::mhlo::Transpose::TRANSPOSE:
      return TriangularSolveOptions::TRANSPOSE;
    case mlir::mhlo::Transpose::ADJOINT:
      return TriangularSolveOptions::ADJOINT;
    case mlir::mhlo::Transpose::TRANSPOSE_INVALID:
      return TriangularSolveOptions::TRANSPOSE_INVALID;
    default:
      return InvalidArgument("Unknown transpose enum value #%d", *transpose);
  }
}
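
// Converts the MHLO custom-call API version enum to its XLA counterpart.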
StatusOr<xla::CustomCallApiVersion> ConvertCustomCallApiVersion(
    mlir::mhlo::CustomCallApiVersion api_version) {
  switch (api_version) {
    case mlir::mhlo::CustomCallApiVersion::API_VERSION_UNSPECIFIED:
      return xla::CustomCallApiVersion::API_VERSION_UNSPECIFIED;
    case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL:
      return xla::CustomCallApiVersion::API_VERSION_ORIGINAL;
    case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
      return xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING;
    case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED:
      return xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED;
    default:
      return InvalidArgument("Unknown CustomCallApiVersion enum value #%d",
                             api_version);
  }
}

}  // namespace xla