[XLA] Fix elemental_ir_emitter access control:
* All data members are private now, and accessed through methods by derived classes.
* MakeElementGenerator isn't virtual anymore. Instead, the customization point is moved to EmitConvolution.

PiperOrigin-RevId: 338747947
Change-Id: I3a8382eed86f1104b46226afd50a158dddacbf3c
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index 05364a4..b15aa36 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -41,8 +41,8 @@
   switch (prim_type) {
     case F16:
       cast_result_to_fp16 = true;
-      lhs = FPCast(lhs, b_->getFloatTy());
-      rhs = FPCast(rhs, b_->getFloatTy());
+      lhs = FPCast(lhs, b()->getFloatTy());
+      rhs = FPCast(rhs, b()->getFloatTy());
       TF_FALLTHROUGH_INTENDED;
     case F32:
       function_name = "atan2f";
@@ -55,7 +55,7 @@
   }
   // Create a function declaration.
   llvm::Function* function = llvm::dyn_cast<llvm::Function>(
-      module_
+      module()
           ->getOrInsertFunction(function_name, lhs->getType(), lhs->getType(),
                                 rhs->getType())
           .getCallee());
@@ -65,7 +65,7 @@
   // Create an instruction to call the function.
   llvm::Value* result = Call(function, {lhs, rhs});
   if (cast_result_to_fp16) {
-    result = FPCast(result, b_->getHalfTy());
+    result = FPCast(result, b()->getHalfTy());
   }
   return result;
 }
@@ -77,7 +77,7 @@
   switch (prim_type) {
     case F16:
       cast_result_to_fp16 = true;
-      value = FPCast(value, b_->getFloatTy());
+      value = FPCast(value, b()->getFloatTy());
       TF_FALLTHROUGH_INTENDED;
     case F32:
       function_name = "tanhf";
@@ -90,7 +90,7 @@
   }
   // Create a function declaration.
   llvm::Function* function = llvm::dyn_cast<llvm::Function>(
-      module_
+      module()
           ->getOrInsertFunction(function_name, value->getType(),
                                 value->getType())
           .getCallee());
@@ -100,26 +100,20 @@
   // Create an instruction to call the function.
   llvm::Value* result = Call(function, value);
   if (cast_result_to_fp16) {
-    result = FPCast(result, b_->getHalfTy());
+    result = FPCast(result, b()->getHalfTy());
   }
   return result;
 }
 
-llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitConvolution(
     const HloInstruction* hlo,
-    const HloToElementGeneratorMap& operand_to_generator) {
-  switch (hlo->opcode()) {
-    case HloOpcode::kConvolution:
-      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
-        return ir_emitter_->EmitElementalConvolution(
-            Cast<HloConvolutionInstruction>(hlo),
-            operand_to_generator.at(hlo->operand(0)),
-            operand_to_generator.at(hlo->operand(1)), index);
-      };
-    default:
-      return ElementalIrEmitter::MakeElementGenerator(hlo,
-                                                      operand_to_generator);
-  }
+    const HloToElementGeneratorMap& operand_to_generator,
+    const llvm_ir::IrArray::Index& index) {
+  return ir_emitter_->EmitElementalConvolution(
+      Cast<HloConvolutionInstruction>(hlo),
+      operand_to_generator.at(hlo->operand(0)),
+      operand_to_generator.at(hlo->operand(1)), index);
 }
+
 }  // namespace cpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index 4c3167e..fbf582d 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -31,18 +31,19 @@
  public:
   CpuElementalIrEmitter(const HloModuleConfig& module_config,
                         IrEmitter* ir_emitter, llvm::Module* module)
-      : ElementalIrEmitter(module_config, module, ir_emitter->b()),
+      : ElementalIrEmitter(module, ir_emitter->b()),
+        hlo_module_config_(module_config),
         ir_emitter_(ir_emitter) {}
 
-  llvm_ir::ElementGenerator MakeElementGenerator(
-      const HloInstruction* hlo,
-      const HloToElementGeneratorMap& operand_to_generator) override;
-
  protected:
   StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
                                    llvm::Value* rhs) override;
   StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                   llvm::Value* value) override;
+  StatusOr<llvm::Value*> EmitConvolution(
+      const HloInstruction* hlo,
+      const HloToElementGeneratorMap& operand_to_generator,
+      const llvm_ir::IrArray::Index& index) override;
 
   StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
       const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
@@ -54,6 +55,7 @@
     return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max();
   }
 
+  const HloModuleConfig& hlo_module_config_;
   IrEmitter* ir_emitter_;
 };
 
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 3a449b7..d3e00d0 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -2486,6 +2486,10 @@
         return EmitElementalReduce(reduce_instr, std::move(input_generators),
                                    std::move(initial_value_generators), index);
       };
+    case HloOpcode::kConvolution:
+      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
+        return EmitConvolution(hlo, operand_to_generator, index);
+      };
     default:
       return [hlo](const IrArray::Index& index) {
         return Unimplemented("Unhandled opcode for elemental IR emission: %s",
@@ -2730,6 +2734,13 @@
   }
 }
 
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
+    const HloInstruction* hlo,
+    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+    const llvm_ir::IrArray::Index& index) {
+  return Unimplemented("Elemental convolution is not implemented");
+}
+
 // Evaluate polynomial using Horner's method.
 StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
     llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 365e3f5..5683315 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -26,7 +26,6 @@
 #include "llvm/IR/Value.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
-#include "tensorflow/compiler/xla/service/hlo_module_config.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
@@ -39,22 +38,14 @@
   using HloToElementGeneratorMap =
       std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
 
-  ElementalIrEmitter(const HloModuleConfig& hlo_module_config,
-                     llvm::Module* module, llvm::IRBuilder<>* b)
-      : b_(b), module_(module), hlo_module_config_(hlo_module_config) {}
+  ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b)
+      : b_(b), module_(module) {}
 
   virtual ~ElementalIrEmitter() = default;
 
-  virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
-                                             llvm::Value* operand_value);
-
-  virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
-                                              llvm::Value* lhs_value,
-                                              llvm::Value* rhs_value);
-
   // Returns a function to generate an element of the output of `hlo`, given a
   // map of functions to generate elements of its operands.
-  virtual llvm_ir::ElementGenerator MakeElementGenerator(
+  llvm_ir::ElementGenerator MakeElementGenerator(
       const HloInstruction* hlo,
       const HloToElementGeneratorMap& operand_to_generator);
 
@@ -66,6 +57,21 @@
   llvm::Module* module() { return module_; }
 
  protected:
+  virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
+                                                   llvm::Value* lhs_value,
+                                                   llvm::Value* rhs_value);
+
+  virtual llvm::Value* EmitExtractReal(llvm::Value* value);
+  virtual llvm::Value* EmitExtractImag(llvm::Value* value);
+
+ private:
+  virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
+                                             llvm::Value* operand_value);
+
+  virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
+                                              llvm::Value* lhs_value,
+                                              llvm::Value* rhs_value);
+
   virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op,
                                                     llvm::Value* operand_value);
 
@@ -92,10 +98,6 @@
                                                      llvm::Value* rhs_value,
                                                      bool is_signed);
 
-  virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
-                                                   llvm::Value* lhs_value,
-                                                   llvm::Value* rhs_value);
-
   virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op,
                                                      llvm::Value* lhs_value,
                                                      llvm::Value* rhs_value);
@@ -175,9 +177,6 @@
                                                   PrimitiveType prim_type,
                                                   llvm::Value* operand_value);
 
-  virtual llvm::Value* EmitExtractReal(llvm::Value* value);
-  virtual llvm::Value* EmitExtractImag(llvm::Value* value);
-
   // Composes a complex struct. imag may be nullptr for simple cast operations.
   llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
                                   llvm::Value* imag);
@@ -245,17 +244,11 @@
       std::vector<llvm_ir::ElementGenerator> initial_value_generators,
       const llvm_ir::IrArray::Index& index);
 
-  virtual bool fast_min_max() = 0;
+  virtual StatusOr<llvm::Value*> EmitConvolution(
+      const HloInstruction* hlo,
+      const HloToElementGeneratorMap& operand_to_generator,
+      const llvm_ir::IrArray::Index& index);
 
-  llvm::IRBuilder<>* const b_;
-
-  llvm::Module* module_;
-
-  // The HloModuleConfig which gathers all settings and values which affect the
-  // compiled executable outside of the HLO code itself.
-  const HloModuleConfig& hlo_module_config_;
-
- private:
   // Computes the complex power function, returns (a + i*b)^(c + i*d).
   StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
                                           llvm::Value* a, llvm::Value* b,
@@ -264,6 +257,12 @@
   // Evaluates a polynomial using Horner's method.
   StatusOr<llvm::Value*> EvaluatePolynomial(
       llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients);
+
+  virtual bool fast_min_max() = 0;
+
+  llvm::IRBuilder<>* const b_;
+
+  llvm::Module* module_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 3f000a2..e72c128 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -72,7 +72,8 @@
 GpuElementalIrEmitter::GpuElementalIrEmitter(
     const HloModuleConfig& hlo_module_config, llvm::Module* module,
     llvm::IRBuilder<>* b, NestedComputer compute_nested)
-    : ElementalIrEmitter(hlo_module_config, module, b),
+    : ElementalIrEmitter(module, b),
+      hlo_module_config_(hlo_module_config),
       compute_nested_(std::move(compute_nested)) {}
 
 StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
@@ -91,7 +92,7 @@
       for (int64 i = 0; i < operands.size(); ++i) {
         if (input_types[i] == F16) {
           converted_operands[i] =
-              FPCast(converted_operands[i], b_->getFloatTy());
+              FPCast(converted_operands[i], b()->getFloatTy());
           converted_input_types[i] = F32;
         }
       }
@@ -106,12 +107,12 @@
                            PrimitiveType_Name(output_type));
   }
   const string& munged_callee =
-      ObtainDeviceFunctionName(funcid, output_type, b_);
+      ObtainDeviceFunctionName(funcid, output_type, b());
   llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
                                      converted_input_types, output_type)
                             .ValueOrDie();
   if (cast_result_to_fp16) {
-    result = FPCast(result, b_->getHalfTy());
+    result = FPCast(result, b()->getHalfTy());
   }
   return result;
 }
@@ -153,7 +154,7 @@
 
   return EmitDeviceFunctionCall(
       callee_name, operands, input_types, output_type,
-      {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b_);
+      {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b());
 }
 
 StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
@@ -168,7 +169,7 @@
     return llvm_ir::EmitCallToIntrinsic(
         opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
                                       : llvm::Intrinsic::minnum,
-        {lhs_value, rhs_value}, {lhs_value->getType()}, b_);
+        {lhs_value, rhs_value}, {lhs_value->getType()}, b());
   }
 
   switch (op->opcode()) {
@@ -275,19 +276,19 @@
   // This routine isn't numerically precise, but it's good enough for ML.
 
   // Upcast F16 to F32 if necessary.
-  llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
+  llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType();
   llvm::Value* input = FPCast(value, type);
 
   // If |value| >= kMaxValue, tanh() is set to -1.0 or 1.0.
   constexpr double kMaxValue = 20.0;
   auto max_value = llvm::ConstantFP::get(type, kMaxValue);
   llvm::Value* abs_value =
-      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b_);
+      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b());
 
-  llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
+  llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input);
   auto one = llvm::ConstantFP::get(type, 1.0);
   auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
-                                                    {one, input}, {type}, b_);
+                                                    {one, input}, {type}, b());
   return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
                 value->getType(), "tanh");
 }
@@ -301,14 +302,14 @@
 
 llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
   llvm::Value* block_id = IntCast(
-      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_),
-      b_->getIntNTy(128), /*isSigned=*/true, "block.id");
+      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()),
+      b()->getIntNTy(128), /*isSigned=*/true, "block.id");
   llvm::Value* thread_id_in_block = IntCast(
-      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_),
-      b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
+      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()),
+      b()->getIntNTy(128), /*isSigned=*/true, "thread.id");
   llvm::Value* threads_per_block = IntCast(
-      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b_),
-      b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
+      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()),
+      b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
   return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 766a4c8..0303ea4 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -126,6 +126,8 @@
       const string& callee_name, absl::Span<llvm::Value* const> operands,
       absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
 
+  const HloModuleConfig& hlo_module_config_;
+
   NestedComputer compute_nested_;
 };