blob: ceea3e28be2292dbaffe5a69b79587eadd526215 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_LHLO_DIALECT_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_LHLO_DIALECT_EMITTER_H_
#include <memory>
#include "absl/container/flat_hash_map.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
#include "tensorflow/compiler/xla/status.h"
namespace xla {
namespace mlir_gpu {
// Translates HLO instructions to a gpu::ThunkSequence via MLIR, using the LHLO
// dialect.
//
// Implements the DfsHloVisitor interface: LHLO computations are emitted as
// MLIR functions and then transformed into gpu::Thunks. The class also
// privately implements gpu::ThunkEmitter::EmissionContext, the callback
// interface required by ThunkEmitter (see the private overrides below).
class LhloDialectEmitter : public DfsHloVisitorWithDefault,
                           private gpu::ThunkEmitter::EmissionContext {
 public:
  // `assignment` is stored by reference and must outlive this emitter.
  // NOTE(review): ownership of `emission_context` and `platform` is not
  // visible in this header; they are presumably borrowed and must also
  // outlive the emitter — confirm against the .cc file.
  LhloDialectEmitter(xla::mlir_gpu::EmissionContext* emission_context,
                     const BufferAssignment& assignment,
                     const se::Platform* platform,
                     ::mlir::ModuleOp mlir_module);
  ~LhloDialectEmitter() override = default;

  // Emits `computation`.
  // NOTE(review): presumably visits every instruction of the computation with
  // this visitor — confirm in the implementation.
  Status EmitComputation(const HloComputation& computation);

  // The following methods implement the DfsHloVisitor interface.
  //
  // Default action which emits code for most operations. Operations which are
  // special in some way are handled explicitly in the HandleFoo methods.
  Status DefaultAction(HloInstruction* instr) override;
  Status HandleBroadcast(HloInstruction* broadcast) override;
  Status HandleCompare(HloInstruction* compare) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleIota(HloInstruction* iota) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status FinishVisit(HloInstruction* root) override;

  // Transfers ownership of thunk_sequence_ to the caller. The member is left
  // as a null unique_ptr, so any subsequent call returns nullptr.
  std::unique_ptr<gpu::ThunkSequence> ConsumeThunkSequence() {
    return std::move(thunk_sequence_);
  }

  // Returns the map from each HLO instruction to the MLIR function recorded
  // for it. The returned reference is valid for the lifetime of this emitter.
  const absl::flat_hash_map<const xla::HloInstruction*, ::mlir::FuncOp>&
  InstructionToFunctionMap() const {
    return instruction_to_mlir_func_;
  }

 private:
  // Creates the MLIR function for `instr`.
  // NOTE(review): presumably also records it in instruction_to_mlir_func_ —
  // confirm in the implementation.
  StatusOr<::mlir::FuncOp> CreateFunction(const HloInstruction& instr);

  // Interface required by ThunkEmitter (gpu::ThunkEmitter::EmissionContext).
  void AddThunkToThunkSequence(std::unique_ptr<gpu::Thunk> thunk) override;
  StatusOr<BufferAllocation::Slice> MaybeGetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index) const override;
  int64 ByteSizeOf(const Shape& shape) const override;
  const se::Platform* platform() const override;

  // Returns the MLIR location to associate with `instr`.
  mlir::Location getLocation(const HloInstruction* instr) const;

  // NOTE(review): the declaration order of the members below determines their
  // initialization order. builder_ follows mlir_module_, which matters if the
  // constructor builds builder_ from the module — do not reorder.
  xla::mlir_gpu::EmissionContext* emission_context_;
  ::mlir::ModuleOp mlir_module_;
  ::mlir::Builder builder_;
  // MLIR function recorded for each HLO instruction.
  absl::flat_hash_map<const xla::HloInstruction*, ::mlir::FuncOp>
      instruction_to_mlir_func_;
  // Not owned; must outlive this emitter (it is held by reference).
  const BufferAssignment& buffer_assignment_;
  const se::Platform* platform_;
  // Cached pointer size extracted from the MLIR module. Zero-initialized so
  // the field never holds an indeterminate value; the constructor is expected
  // to overwrite it.
  unsigned pointer_size_ = 0;
  // The thunk sequence this LhloDialectEmitter generates for the input
  // computation. Released to callers by ConsumeThunkSequence().
  std::unique_ptr<gpu::ThunkSequence> thunk_sequence_;

  TF_DISALLOW_COPY_AND_ASSIGN(LhloDialectEmitter);
};
} // namespace mlir_gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_LHLO_DIALECT_EMITTER_H_