[XLA:CPU] Move elemental conv emitter to generic code path so it can share the multiply-add logic with dot

This fixes some edge cases with complex numbers and would also allow using
the emitter for integer types if we want that. The GPU backend doesn't use this code path.

PiperOrigin-RevId: 351549079
Change-Id: I7a2f9e9758e270b62814c7e3e0419342c5b58196
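
Note: the new shared helper EmitMulAdd (added to elemental_ir_emitter.cc below)
computes `accumulator + lhs * rhs` with per-type semantics. The old convolution
emitter used bare FMul/FAdd, which is only correct for real floating-point
values. As a rough scalar model of the arithmetic it emits (these overloads are
illustrative, not XLA code):

    // Scalar model of the arithmetic EmitMulAdd emits, one overload per case.
    #include <complex>
    #include <cstdint>
    #include <iostream>

    // Complex: full (a+bi)(c+di) = (ac-bd) + (ad+bc)i product, accumulated
    // component-wise. This is the case plain FMul/FAdd handled incorrectly.
    std::complex<float> MulAdd(std::complex<float> lhs, std::complex<float> rhs,
                               std::complex<float> acc) {
      float product_real = lhs.real() * rhs.real() - lhs.imag() * rhs.imag();
      float product_imag = lhs.real() * rhs.imag() + lhs.imag() * rhs.real();
      return {acc.real() + product_real, acc.imag() + product_imag};
    }

    // Floating point: multiply, then add into the accumulator (which may be
    // wider than the operands, e.g. F32 accumulation of F16 products).
    float MulAdd(float lhs, float rhs, float acc) { return acc + lhs * rhs; }

    // PRED: a boolean dot reduces with OR over AND.
    bool MulAdd(bool lhs, bool rhs, bool acc) { return acc || (lhs && rhs); }

    // Integers: plain integer multiply-add.
    int32_t MulAdd(int32_t lhs, int32_t rhs, int32_t acc) {
      return acc + lhs * rhs;
    }

    int main() {
      std::complex<float> acc{0.0f, 0.0f};
      acc = MulAdd({1.0f, 2.0f}, {3.0f, 4.0f}, acc);  // (1+2i)(3+4i) = -5+10i
      std::cout << acc << "\n";  // prints (-5,10)
    }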
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index eade412..2baa44c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4203,12 +4203,12 @@
     deps = [
         ":hlo",
         ":hlo_casting_utils",
-        ":hlo_module_config",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
         "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index b15aa36..a4566b1 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -105,15 +105,5 @@
   return result;
 }
 
-StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitConvolution(
-    const HloInstruction* hlo,
-    const HloToElementGeneratorMap& operand_to_generator,
-    const llvm_ir::IrArray::Index& index) {
-  return ir_emitter_->EmitElementalConvolution(
-      Cast<HloConvolutionInstruction>(hlo),
-      operand_to_generator.at(hlo->operand(0)),
-      operand_to_generator.at(hlo->operand(1)), index);
-}
-
 }  // namespace cpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index fbf582d..a002df2 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -40,10 +40,6 @@
                                    llvm::Value* rhs) override;
   StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                   llvm::Value* value) override;
-  StatusOr<llvm::Value*> EmitConvolution(
-      const HloInstruction* hlo,
-      const HloToElementGeneratorMap& operand_to_generator,
-      const llvm_ir::IrArray::Index& index) override;
 
   StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
       const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index e1765f4..7179c5a 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -856,152 +856,6 @@
                           hlo_module_config_, target_machine_features_);
 }
 
-StatusOr<llvm::Value*> IrEmitter::EmitElementalConvolution(
-    const HloConvolutionInstruction* convolution,
-    const llvm_ir::ElementGenerator& input_generator,
-    const llvm_ir::ElementGenerator& kernel_generator,
-    const llvm_ir::IrArray::Index& index) {
-  const HloInstruction* lhs = convolution->operand(0);
-  const HloInstruction* rhs = convolution->operand(1);
-  const Window& window = convolution->window();
-
-  const ConvolutionDimensionNumbers& dnums =
-      convolution->convolution_dimension_numbers();
-  int num_spatial_dims = dnums.output_spatial_dimensions_size();
-  std::vector<llvm::Value*> output_spatial(num_spatial_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
-  }
-  llvm::Value* output_feature = index[dnums.output_feature_dimension()];
-  llvm::Value* batch = index[dnums.output_batch_dimension()];
-
-  // We will accumulate the products into this sum to calculate the output entry
-  // at the given index.
-  PrimitiveType lhs_element_type = lhs->shape().element_type();
-  llvm::Type* lhs_llvm_type =
-      llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
-  // Upcast the accumulator to F32 from F16 for increased precision.
-  llvm::Type* accumulator_type =
-      lhs_element_type == F16 ? b_.getFloatTy() : lhs_llvm_type;
-  llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
-      accumulator_type, "convolution_sum_address", &b_,
-      MinimumAlignmentForPrimitiveType(lhs_element_type));
-  llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
-  Store(constant_zero, sum_address);
-
-  llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
-  std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    kernel_spatial[i] =
-        loops
-            .AddLoop(
-                0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
-                absl::StrCat("k", i))
-            ->GetIndVarValue();
-  }
-  llvm::Value* input_feature =
-      loops
-          .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
-                   "iz")
-          ->GetIndVarValue();
-
-  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
-
-  // Calculate the spatial index in the input array, taking striding, dilation
-  // and padding into account. An index in the padding will be out of the bounds
-  // of the array.
-  const auto calculate_input_index = [this](llvm::Value* output_index,
-                                            llvm::Value* kernel_index,
-                                            const WindowDimension& window_dim) {
-    llvm::Value* strided_index =
-        NSWMul(output_index, b_.getInt64(window_dim.stride()));
-    llvm::Value* dilated_kernel_index =
-        NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation()));
-    return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
-                  b_.getInt64(window_dim.padding_low()));
-  };
-  std::vector<llvm::Value*> input_spatial(num_spatial_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    input_spatial[i] = calculate_input_index(
-        output_spatial[i], kernel_spatial[i], window.dimensions(i));
-  }
-
-  // We need to check if 0 <= input dim < bound, as otherwise we are in the
-  // padding so that we can skip the computation. That is equivalent to input
-  // dim < bound as an *unsigned* comparison, since a negative value will wrap
-  // to a large positive value. The input dim is dilated, so we need to dilate
-  // the bound as well to match.
-
-  // Also need to check that the input coordinates are not in one of the
-  // holes created by base dilation.
-  const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
-    llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation));
-    return ICmpEQ(remainder, b_.getInt64(0));
-  };
-
-  llvm::Value* in_bounds_condition = b_.getInt1(true);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound(
-        lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
-        window.dimensions(i).base_dilation()));
-    llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
-    llvm::Value* dim_not_in_hole =
-        not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
-    llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
-    in_bounds_condition = And(in_bounds_condition, dim_ok);
-  }
-
-  // Now we need to map the dilated base coordinates back to the actual
-  // data indices on the lhs.
-  const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
-    return SDiv(input_index, b_.getInt64(base_dilation));
-  };
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    input_spatial[i] =
-        undilate(input_spatial[i], window.dimensions(i).base_dilation());
-  }
-
-  llvm_ir::LlvmIfData if_data =
-      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
-  SetToFirstInsertPoint(if_data.true_block, &b_);
-
-  // We are not in the padding, so carry out the computation.
-  int num_dims = num_spatial_dims + 2;
-  std::vector<llvm::Value*> input_multi_index(num_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
-  }
-  input_multi_index[dnums.input_feature_dimension()] = input_feature;
-  input_multi_index[dnums.input_batch_dimension()] = batch;
-
-  std::vector<llvm::Value*> kernel_multi_index(num_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
-        window.dimensions(i).window_reversal()
-            ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1),
-                     kernel_spatial[i])
-            : kernel_spatial[i];
-  }
-
-  kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
-  kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
-
-  llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
-                                      b_.getInt64Ty());
-  TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
-                      input_generator(input_index));
-  llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
-                                       b_.getInt64Ty());
-  TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
-                      kernel_generator(kernel_index));
-  llvm::Value* product = FMul(input_value, kernel_value);
-  llvm::Value* sum = FAdd(Load(sum_address), FPCast(product, accumulator_type));
-  Store(sum, sum_address);
-
-  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-  return FPCast(Load(sum_address), lhs_llvm_type);
-}
-
 Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
   auto lhs = convolution->operand(0);
   auto rhs = convolution->operand(1);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 891d53c..49490ef 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -121,13 +121,6 @@
   // Emit an LLVM global variable for every constant buffer allocation.
   Status EmitConstantGlobals();
 
-  // Emit code to emit the element at `index` for a convolution instruction.
-  StatusOr<llvm::Value*> EmitElementalConvolution(
-      const HloConvolutionInstruction* convolution,
-      const llvm_ir::ElementGenerator& input_generator,
-      const llvm_ir::ElementGenerator& kernel_generator,
-      const llvm_ir::IrArray::Index& index);
-
  protected:
   //
   // The following methods implement the DfsHloVisitor interface.
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 735b0b7..817d3e6 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -42,6 +42,7 @@
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/logging.h"
@@ -2222,27 +2223,8 @@
   llvm::Value* current_accumulator = Load(accumulator_alloca);
   TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
   TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
-  llvm::Value* next_accumulator;
-  if (primitive_util::IsComplexType(primitive_type)) {
-    llvm::Value* product_real =
-        FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
-             FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
-    llvm::Value* product_imag =
-        FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
-             FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
-    next_accumulator = InsertValue(
-        current_accumulator,
-        FAdd(EmitExtractReal(current_accumulator), product_real), {0});
-    next_accumulator = InsertValue(
-        next_accumulator,
-        FAdd(EmitExtractImag(current_accumulator), product_imag), {1});
-  } else if (primitive_util::IsFloatingPointType(primitive_type)) {
-    next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value));
-  } else if (primitive_type == PRED) {
-    next_accumulator = Or(current_accumulator, And(lhs_value, rhs_value));
-  } else {
-    next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value));
-  }
+  llvm::Value* next_accumulator =
+      EmitMulAdd(lhs_value, rhs_value, current_accumulator, primitive_type);
   Store(next_accumulator, accumulator_alloca);
 
   SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
@@ -2551,6 +2533,28 @@
   return complex;
 }
 
+llvm::Value* ElementalIrEmitter::EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
+                                            llvm::Value* accumulator,
+                                            xla::PrimitiveType primitive_type) {
+  if (primitive_util::IsComplexType(primitive_type)) {
+    llvm::Value* product_real =
+        FSub(FMul(EmitExtractReal(lhs), EmitExtractReal(rhs)),
+             FMul(EmitExtractImag(lhs), EmitExtractImag(rhs)));
+    llvm::Value* product_imag =
+        FAdd(FMul(EmitExtractReal(lhs), EmitExtractImag(rhs)),
+             FMul(EmitExtractImag(lhs), EmitExtractReal(rhs)));
+    llvm::Value* next_accumulator = InsertValue(
+        accumulator, FAdd(EmitExtractReal(accumulator), product_real), {0});
+    return InsertValue(next_accumulator,
+                       FAdd(EmitExtractImag(accumulator), product_imag), {1});
+  } else if (primitive_util::IsFloatingPointType(primitive_type)) {
+    return FAdd(accumulator, FPCast(FMul(lhs, rhs), accumulator->getType()));
+  } else if (primitive_type == PRED) {
+    return Or(accumulator, And(lhs, rhs));
+  }
+  return Add(accumulator, Mul(lhs, rhs));
+}
+
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
     const HloMapInstruction* map_instr,
     absl::Span<llvm::Value* const> elemental_operands) {
@@ -2767,10 +2771,149 @@
 }
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
-    const HloInstruction* hlo,
+    const HloInstruction* convolution,
     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
     const llvm_ir::IrArray::Index& index) {
-  return Unimplemented("Elemental convolution is not implemented");
+  const HloInstruction* lhs = convolution->operand(0);
+  const auto& input_generator = operand_to_generator.at(lhs);
+  const HloInstruction* rhs = convolution->operand(1);
+  const auto& kernel_generator = operand_to_generator.at(rhs);
+  const Window& window = convolution->window();
+
+  const ConvolutionDimensionNumbers& dnums =
+      convolution->convolution_dimension_numbers();
+  int num_spatial_dims = dnums.output_spatial_dimensions_size();
+  std::vector<llvm::Value*> output_spatial(num_spatial_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
+  }
+  llvm::Value* output_feature = index[dnums.output_feature_dimension()];
+  llvm::Value* batch = index[dnums.output_batch_dimension()];
+
+  // We will accumulate the products into this sum to calculate the output entry
+  // at the given index.
+  PrimitiveType lhs_element_type = lhs->shape().element_type();
+  llvm::Type* lhs_llvm_type =
+      llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
+  // Upcast the accumulator to F32 from F16 for increased precision.
+  llvm::Type* accumulator_type =
+      lhs_element_type == F16 ? b_->getFloatTy() : lhs_llvm_type;
+  llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
+      accumulator_type, "convolution_sum_address", b_);
+  llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
+  Store(constant_zero, sum_address);
+
+  llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), b_);
+  std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    kernel_spatial[i] =
+        loops
+            .AddLoop(
+                0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
+                absl::StrCat("k", i))
+            ->GetIndVarValue();
+  }
+  llvm::Value* input_feature =
+      loops
+          .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
+                   "iz")
+          ->GetIndVarValue();
+
+  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
+
+  // Calculate the spatial index in the input array, taking striding, dilation
+  // and padding into account. An index in the padding will be out of the bounds
+  // of the array.
+  const auto calculate_input_index = [this](llvm::Value* output_index,
+                                            llvm::Value* kernel_index,
+                                            const WindowDimension& window_dim) {
+    llvm::Value* strided_index =
+        NSWMul(output_index, b_->getInt64(window_dim.stride()));
+    llvm::Value* dilated_kernel_index =
+        NSWMul(kernel_index, b_->getInt64(window_dim.window_dilation()));
+    return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
+                  b_->getInt64(window_dim.padding_low()));
+  };
+  std::vector<llvm::Value*> input_spatial(num_spatial_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    input_spatial[i] = calculate_input_index(
+        output_spatial[i], kernel_spatial[i], window.dimensions(i));
+  }
+
+  // We need to check if 0 <= input dim < bound, as otherwise we are in the
+  // padding so that we can skip the computation. That is equivalent to input
+  // dim < bound as an *unsigned* comparison, since a negative value will wrap
+  // to a large positive value. The input dim is dilated, so we need to dilate
+  // the bound as well to match.
+
+  // Also need to check that the input coordinates are not in one of the
+  // holes created by base dilation.
+  const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
+    llvm::Value* remainder = SRem(input_index, b_->getInt64(base_dilation));
+    return ICmpEQ(remainder, b_->getInt64(0));
+  };
+
+  llvm::Value* in_bounds_condition = b_->getInt1(true);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    llvm::ConstantInt* input_bound = b_->getInt64(window_util::DilatedBound(
+        lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
+        window.dimensions(i).base_dilation()));
+    llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
+    llvm::Value* dim_not_in_hole =
+        not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
+    llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
+    in_bounds_condition = And(in_bounds_condition, dim_ok);
+  }
+
+  // Now we need to map the dilated base coordinates back to the actual
+  // data indices on the lhs.
+  const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
+    return SDiv(input_index, b_->getInt64(base_dilation));
+  };
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    input_spatial[i] =
+        undilate(input_spatial[i], window.dimensions(i).base_dilation());
+  }
+
+  llvm_ir::LlvmIfData if_data =
+      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", b_);
+  SetToFirstInsertPoint(if_data.true_block, b_);
+
+  // We are not in the padding, so carry out the computation.
+  int num_dims = num_spatial_dims + 2;
+  std::vector<llvm::Value*> input_multi_index(num_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
+  }
+  input_multi_index[dnums.input_feature_dimension()] = input_feature;
+  input_multi_index[dnums.input_batch_dimension()] = batch;
+
+  std::vector<llvm::Value*> kernel_multi_index(num_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
+        window.dimensions(i).window_reversal()
+            ? NSWSub(b_->getInt64(window.dimensions(i).size() - 1),
+                     kernel_spatial[i])
+            : kernel_spatial[i];
+  }
+
+  kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
+  kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
+
+  llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
+                                      b_->getInt64Ty());
+  TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
+                      input_generator(input_index));
+  llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
+                                       b_->getInt64Ty());
+  TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
+                      kernel_generator(kernel_index));
+  llvm::Value* sum = EmitMulAdd(input_value, kernel_value, Load(sum_address),
+                                convolution->shape().element_type());
+  Store(sum, sum_address);
+
+  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
+  return FPCast(Load(sum_address), lhs_llvm_type);
 }
 
 // Evaluate polynomial using Horner's method.
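
Note: the spatial index computation in EmitConvolution above maps an
(output, kernel) position to an input position and rejects positions that land
in padding or in the holes created by base dilation. A scalar model of that
mapping (illustrative only; window_util::DilatedBound(b, d) is
(b - 1) * d + 1 for b > 0):

    // Scalar model of calculate_input_index + bounds/hole checks + undilate.
    // Returns true and sets *input_index when the position reads real input
    // data; false when it falls in padding or in a base-dilation hole.
    #include <cstdint>

    struct WindowDim {
      int64_t stride, window_dilation, base_dilation, padding_low;
    };

    bool MapToInput(int64_t output_index, int64_t kernel_index,
                    const WindowDim& w, int64_t input_bound,
                    int64_t* input_index) {
      // Stride, kernel dilation, and low padding, as in calculate_input_index.
      int64_t dilated = output_index * w.stride +
                        kernel_index * w.window_dilation - w.padding_low;
      // Dilated bound of the input dimension.
      int64_t dilated_bound = (input_bound - 1) * w.base_dilation + 1;
      // 0 <= dilated < dilated_bound, done as one unsigned compare since a
      // negative value wraps to a large positive one.
      if (static_cast<uint64_t>(dilated) >=
          static_cast<uint64_t>(dilated_bound)) {
        return false;  // in the padding
      }
      if (dilated % w.base_dilation != 0) {
        return false;  // in a hole created by base dilation
      }
      *input_index = dilated / w.base_dilation;  // undilate
      return true;
    }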
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 60e25c7..5cf368f 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -183,6 +183,11 @@
   llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
                                   llvm::Value* imag);
 
+  // Emit `accumulator + lhs * rhs` for the given primitive type.
+  llvm::Value* EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
+                          llvm::Value* accumulator,
+                          xla::PrimitiveType primitive_type);
+
   // Identifier of the thread unique among all threads on the device
   virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); }
 
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 8337f93..2802c66 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -1644,6 +1644,18 @@
   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
 }
 
+XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU(ConvolveC64Forward)) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+ENTRY Test {
+  %arg0 = c64[3,56,56,16] parameter(0)
+  %arg1 = c64[3,3,3,64] parameter(1)
+  ROOT %conv = c64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
+})";
+  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
+}
+
 XLA_TEST_F(ConvolutionHloTest,
            DISABLED_ON_GPU_ROCM(ConvolveF32ForwardReversed)) {
   constexpr char kHlo[] = R"(