[XLA] Fix elemental_ir_emitter access control:
* All data members are now private; derived classes access them through accessor methods.
* MakeElementGenerator isn't virtual anymore. Instead, the customization point is moved to EmitConvolution.
PiperOrigin-RevId: 338747947
Change-Id: I3a8382eed86f1104b46226afd50a158dddacbf3c
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index 05364a4..b15aa36 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -41,8 +41,8 @@
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
- lhs = FPCast(lhs, b_->getFloatTy());
- rhs = FPCast(rhs, b_->getFloatTy());
+ lhs = FPCast(lhs, b()->getFloatTy());
+ rhs = FPCast(rhs, b()->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "atan2f";
@@ -55,7 +55,7 @@
}
// Create a function declaration.
llvm::Function* function = llvm::dyn_cast<llvm::Function>(
- module_
+ module()
->getOrInsertFunction(function_name, lhs->getType(), lhs->getType(),
rhs->getType())
.getCallee());
@@ -65,7 +65,7 @@
// Create an instruction to call the function.
llvm::Value* result = Call(function, {lhs, rhs});
if (cast_result_to_fp16) {
- result = FPCast(result, b_->getHalfTy());
+ result = FPCast(result, b()->getHalfTy());
}
return result;
}
@@ -77,7 +77,7 @@
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
- value = FPCast(value, b_->getFloatTy());
+ value = FPCast(value, b()->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "tanhf";
@@ -90,7 +90,7 @@
}
// Create a function declaration.
llvm::Function* function = llvm::dyn_cast<llvm::Function>(
- module_
+ module()
->getOrInsertFunction(function_name, value->getType(),
value->getType())
.getCallee());
@@ -100,26 +100,20 @@
// Create an instruction to call the function.
llvm::Value* result = Call(function, value);
if (cast_result_to_fp16) {
- result = FPCast(result, b_->getHalfTy());
+ result = FPCast(result, b()->getHalfTy());
}
return result;
}
-llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitConvolution(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) {
- switch (hlo->opcode()) {
- case HloOpcode::kConvolution:
- return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
- return ir_emitter_->EmitElementalConvolution(
- Cast<HloConvolutionInstruction>(hlo),
- operand_to_generator.at(hlo->operand(0)),
- operand_to_generator.at(hlo->operand(1)), index);
- };
- default:
- return ElementalIrEmitter::MakeElementGenerator(hlo,
- operand_to_generator);
- }
+ const HloToElementGeneratorMap& operand_to_generator,
+ const llvm_ir::IrArray::Index& index) {
+ return ir_emitter_->EmitElementalConvolution(
+ Cast<HloConvolutionInstruction>(hlo),
+ operand_to_generator.at(hlo->operand(0)),
+ operand_to_generator.at(hlo->operand(1)), index);
}
+
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index 4c3167e..fbf582d 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -31,18 +31,19 @@
public:
CpuElementalIrEmitter(const HloModuleConfig& module_config,
IrEmitter* ir_emitter, llvm::Module* module)
- : ElementalIrEmitter(module_config, module, ir_emitter->b()),
+ : ElementalIrEmitter(module, ir_emitter->b()),
+ hlo_module_config_(module_config),
ir_emitter_(ir_emitter) {}
- llvm_ir::ElementGenerator MakeElementGenerator(
- const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) override;
-
protected:
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
llvm::Value* rhs) override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) override;
+ StatusOr<llvm::Value*> EmitConvolution(
+ const HloInstruction* hlo,
+ const HloToElementGeneratorMap& operand_to_generator,
+ const llvm_ir::IrArray::Index& index) override;
StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
@@ -54,6 +55,7 @@
return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max();
}
+ const HloModuleConfig& hlo_module_config_;
IrEmitter* ir_emitter_;
};
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 3a449b7..d3e00d0 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -2486,6 +2486,10 @@
return EmitElementalReduce(reduce_instr, std::move(input_generators),
std::move(initial_value_generators), index);
};
+ case HloOpcode::kConvolution:
+ return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
+ return EmitConvolution(hlo, operand_to_generator, index);
+ };
default:
return [hlo](const IrArray::Index& index) {
return Unimplemented("Unhandled opcode for elemental IR emission: %s",
@@ -2730,6 +2734,13 @@
}
}
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
+ const HloInstruction* hlo,
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+ const llvm_ir::IrArray::Index& index) {
+ return Unimplemented("Elemental convolution is not implemented");
+}
+
// Evaluate polynomial using Horner's method.
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 365e3f5..5683315 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -26,7 +26,6 @@
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
-#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
@@ -39,22 +38,14 @@
using HloToElementGeneratorMap =
std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
- ElementalIrEmitter(const HloModuleConfig& hlo_module_config,
- llvm::Module* module, llvm::IRBuilder<>* b)
- : b_(b), module_(module), hlo_module_config_(hlo_module_config) {}
+ ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b)
+ : b_(b), module_(module) {}
virtual ~ElementalIrEmitter() = default;
- virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
- llvm::Value* operand_value);
-
- virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
- llvm::Value* lhs_value,
- llvm::Value* rhs_value);
-
// Returns a function to generate an element of the output of `hlo`, given a
// map of functions to generate elements of its operands.
- virtual llvm_ir::ElementGenerator MakeElementGenerator(
+ llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator);
@@ -66,6 +57,21 @@
llvm::Module* module() { return module_; }
protected:
+ virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value);
+
+ virtual llvm::Value* EmitExtractReal(llvm::Value* value);
+ virtual llvm::Value* EmitExtractImag(llvm::Value* value);
+
+ private:
+ virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
+ llvm::Value* operand_value);
+
+ virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value);
+
virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op,
llvm::Value* operand_value);
@@ -92,10 +98,6 @@
llvm::Value* rhs_value,
bool is_signed);
- virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
- llvm::Value* lhs_value,
- llvm::Value* rhs_value);
-
virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
llvm::Value* rhs_value);
@@ -175,9 +177,6 @@
PrimitiveType prim_type,
llvm::Value* operand_value);
- virtual llvm::Value* EmitExtractReal(llvm::Value* value);
- virtual llvm::Value* EmitExtractImag(llvm::Value* value);
-
// Composes a complex struct. imag may be nullptr for simple cast operations.
llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
llvm::Value* imag);
@@ -245,17 +244,11 @@
std::vector<llvm_ir::ElementGenerator> initial_value_generators,
const llvm_ir::IrArray::Index& index);
- virtual bool fast_min_max() = 0;
+ virtual StatusOr<llvm::Value*> EmitConvolution(
+ const HloInstruction* hlo,
+ const HloToElementGeneratorMap& operand_to_generator,
+ const llvm_ir::IrArray::Index& index);
- llvm::IRBuilder<>* const b_;
-
- llvm::Module* module_;
-
- // The HloModuleConfig which gathers all settings and values which affect the
- // compiled executable outside of the HLO code itself.
- const HloModuleConfig& hlo_module_config_;
-
- private:
// Computes the complex power function, returns (a + i*b)^(c + i*d).
StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
llvm::Value* a, llvm::Value* b,
@@ -264,6 +257,12 @@
// Evaluates a polynomial using Horner's method.
StatusOr<llvm::Value*> EvaluatePolynomial(
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients);
+
+ virtual bool fast_min_max() = 0;
+
+ llvm::IRBuilder<>* const b_;
+
+ llvm::Module* module_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 3f000a2..e72c128 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -72,7 +72,8 @@
GpuElementalIrEmitter::GpuElementalIrEmitter(
const HloModuleConfig& hlo_module_config, llvm::Module* module,
llvm::IRBuilder<>* b, NestedComputer compute_nested)
- : ElementalIrEmitter(hlo_module_config, module, b),
+ : ElementalIrEmitter(module, b),
+ hlo_module_config_(hlo_module_config),
compute_nested_(std::move(compute_nested)) {}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
@@ -91,7 +92,7 @@
for (int64 i = 0; i < operands.size(); ++i) {
if (input_types[i] == F16) {
converted_operands[i] =
- FPCast(converted_operands[i], b_->getFloatTy());
+ FPCast(converted_operands[i], b()->getFloatTy());
converted_input_types[i] = F32;
}
}
@@ -106,12 +107,12 @@
PrimitiveType_Name(output_type));
}
const string& munged_callee =
- ObtainDeviceFunctionName(funcid, output_type, b_);
+ ObtainDeviceFunctionName(funcid, output_type, b());
llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
converted_input_types, output_type)
.ValueOrDie();
if (cast_result_to_fp16) {
- result = FPCast(result, b_->getHalfTy());
+ result = FPCast(result, b()->getHalfTy());
}
return result;
}
@@ -153,7 +154,7 @@
return EmitDeviceFunctionCall(
callee_name, operands, input_types, output_type,
- {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b_);
+ {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b());
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
@@ -168,7 +169,7 @@
return llvm_ir::EmitCallToIntrinsic(
opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
: llvm::Intrinsic::minnum,
- {lhs_value, rhs_value}, {lhs_value->getType()}, b_);
+ {lhs_value, rhs_value}, {lhs_value->getType()}, b());
}
switch (op->opcode()) {
@@ -275,19 +276,19 @@
// This routine isn't numerically precise, but it's good enough for ML.
// Upcast F16 to F32 if necessary.
- llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
+ llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType();
llvm::Value* input = FPCast(value, type);
// If |value| >= kMaxValue, tanh() is set to -1.0 or 1.0.
constexpr double kMaxValue = 20.0;
auto max_value = llvm::ConstantFP::get(type, kMaxValue);
llvm::Value* abs_value =
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b_);
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b());
- llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
+ llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input);
auto one = llvm::ConstantFP::get(type, 1.0);
auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
- {one, input}, {type}, b_);
+ {one, input}, {type}, b());
return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
value->getType(), "tanh");
}
@@ -301,14 +302,14 @@
llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
llvm::Value* block_id = IntCast(
- EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_),
- b_->getIntNTy(128), /*isSigned=*/true, "block.id");
+ EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()),
+ b()->getIntNTy(128), /*isSigned=*/true, "block.id");
llvm::Value* thread_id_in_block = IntCast(
- EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_),
- b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
+ EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()),
+ b()->getIntNTy(128), /*isSigned=*/true, "thread.id");
llvm::Value* threads_per_block = IntCast(
- EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b_),
- b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
+ EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()),
+ b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 766a4c8..0303ea4 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -126,6 +126,8 @@
const string& callee_name, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
+ const HloModuleConfig& hlo_module_config_;
+
NestedComputer compute_nested_;
};