blob: f8db5a70360a6828e7017621fa593abcb922db7b [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate.h"
#include "llvm/Support/MemoryBuffer.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
namespace {
// Protobuf text-format ErrorCollector that discards every error reported to
// it. LoadHloProto installs it on its text parser so that a failed
// text-format attempt produces no output while the other encodings are tried.
class NoOpErrorCollector : public tensorflow::protobuf::io::ErrorCollector {
 public:
  // Intentionally empty: the error location and message are ignored.
  void AddError(int /*line*/, int /*column*/,
                const std::string& /*message*/) override {}
};
bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) {
tensorflow::protobuf::TextFormat::Parser parser;
NoOpErrorCollector collector;
parser.RecordErrorsTo(&collector);
return hlo_proto->ParseFromString(contents) ||
parser.ParseFromString(contents, hlo_proto) ||
hlo_proto->mutable_hlo_module()->ParseFromString(contents) ||
parser.ParseFromString(contents, hlo_proto->mutable_hlo_module());
}
} // namespace
// Imports a serialized HloProto or HloModuleProto (binary or text format, as
// accepted by LoadHloProto) into a freshly created MLIR module.
//
// `context` owns the new module. `import_all_computations` is forwarded
// unchanged to ConvertHloToMlirHlo. Returns a null OwningOpRef, after logging
// an error, if the input cannot be decoded or the import fails.
mlir::OwningOpRef<mlir::ModuleOp> HloToMlirHloTranslateFunction(
    llvm::StringRef input, mlir::MLIRContext* context,
    bool import_all_computations) {
  const std::string serialized = input.str();
  HloProto proto;
  if (!LoadHloProto(serialized, &proto)) {
    LOG(ERROR) << "Failed to load proto";
    return nullptr;
  }

  mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  auto import_status = ConvertHloToMlirHlo(
      *mlir_module, proto.mutable_hlo_module(), import_all_computations);
  if (!import_status.ok()) {
    LOG(ERROR) << "Hlo module import failed: " << import_status;
    // Returning null destroys the partially populated module.
    return nullptr;
  }
  return mlir_module;
}
// Parses HLO text (without running the HLO verifier) and imports the
// resulting HloModule into a freshly created MLIR module.
//
// `context` owns the new module. `import_all_computations` is forwarded
// unchanged to ConvertHloToMlirHlo. Returns a null OwningOpRef, after logging
// an error, if parsing or the import fails.
mlir::OwningOpRef<mlir::ModuleOp> HloTextToMlirHloTranslateFunction(
    llvm::StringRef input, mlir::MLIRContext* context,
    bool import_all_computations) {
  // Kept alive for the whole function, matching the lifetime of the text
  // buffer handed to the parser.
  const std::string hlo_text = input.str();
  auto parsed_module = ParseAndReturnUnverifiedModule(hlo_text);
  if (!parsed_module.ok()) {
    LOG(ERROR) << "HLO Module loading failed: " << parsed_module.status();
    return nullptr;
  }
  auto hlo_module = std::move(parsed_module.ValueOrDie());

  mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  auto import_status = ConvertHloToMlirHlo(*mlir_module, hlo_module.get(),
                                           import_all_computations);
  if (!import_status.ok()) {
    LOG(ERROR) << "HLO Module import failed: " << import_status;
    return nullptr;
  }
  return mlir_module;
}
} // namespace xla