Merge "Fix rs_kernel_context_t type mismatch in call to kernel."
diff --git a/lib/Renderscript/RSForEachExpand.cpp b/lib/Renderscript/RSForEachExpand.cpp
index 4d8d823..d1eda6f 100644
--- a/lib/Renderscript/RSForEachExpand.cpp
+++ b/lib/Renderscript/RSForEachExpand.cpp
@@ -389,6 +389,44 @@
     return AfterBB;
   }
 
+  // Finish building the outgoing argument list for calling a ForEach-able function.
+  //
+  // ArgVector - on input, the non-special arguments
+  //             on output, the non-special arguments followed by the special arguments from
+  //               SpecialArgVector, with the context argument (if any) pointer-cast for the callee
+  // SpecialArgVector - special arguments (from ExpandSpecialArguments())
+  // SpecialArgContextIdx - return value of ExpandSpecialArguments()
+  //                          (position of context argument in SpecialArgVector; negative if none)
+  // CalleeFunction - the ForEach-able function being called
+  // CallerBuilder - for inserting code (the pointer cast) into the caller function
+  template<unsigned int ArgVectorLen, unsigned int SpecialArgVectorLen>
+  void finishArgList(      llvm::SmallVector<llvm::Value *, ArgVectorLen>        &ArgVector,
+                     const llvm::SmallVector<llvm::Value *, SpecialArgVectorLen> &SpecialArgVector,
+                     const int SpecialArgContextIdx,
+                     const llvm::Function &CalleeFunction,
+                     llvm::IRBuilder<> &CallerBuilder) {
+    /* The context argument (if any) is a pointer to an opaque user-visible type that differs from
+     * the RsExpandKernelDriverInfoPfx type used in the function we are generating (although the
+     * two types represent the same thing).  Therefore, the context argument is pointer-cast to
+     * the callee's declared parameter type at that position before the call is generated.
+     */
+    const int ArgContextIdx =  // context position in the combined list; stays negative if none
+        SpecialArgContextIdx >= 0 ? (ArgVector.size() + SpecialArgContextIdx) : SpecialArgContextIdx;
+    ArgVector.append(SpecialArgVector.begin(), SpecialArgVector.end());  // special args go last
+    if (ArgContextIdx >= 0) {
+      llvm::Type *ContextArgType = nullptr;
+      int ArgIdx = ArgContextIdx;  // countdown to the callee parameter at ArgContextIdx
+      for (const auto &Arg : CalleeFunction.getArgumentList()) {
+        if (!ArgIdx--) {  // true once ArgContextIdx parameters have been skipped
+          ContextArgType = Arg.getType();
+          break;
+        }
+      }
+      bccAssert(ContextArgType);  // still null if the callee has no parameter at that position
+      ArgVector[ArgContextIdx] = CallerBuilder.CreatePointerCast(ArgVector[ArgContextIdx], ContextArgType);
+    }
+  }
+
 public:
   RSForEachExpandPass(bool pEnableStepOpt = true)
       : ModulePass(ID), Module(nullptr), Context(nullptr),
@@ -410,16 +448,24 @@
   //            suitable for computing arguments for the ForEach-able function
   // CalleeArgs - contribution is accumulated here
   // Bump - invoked once for each contributed outgoing argument
-  void ExpandSpecialArguments(uint32_t Signature,
-                              llvm::Value *X,
-                              llvm::Value *Arg_p,
-                              llvm::IRBuilder<> &Builder,
-                              llvm::SmallVector<llvm::Value*, 8> &CalleeArgs,
-                              std::function<void ()> Bump) {
+  //
+  // Return value is the (zero-based) position of the context (Arg_p)
+  // argument in the CalleeArgs vector, or a negative value if the
+  // context argument is not placed in the CalleeArgs vector.
+  int ExpandSpecialArguments(uint32_t Signature,
+                             llvm::Value *X,
+                             llvm::Value *Arg_p,
+                             llvm::IRBuilder<> &Builder,
+                             llvm::SmallVector<llvm::Value*, 8> &CalleeArgs,
+                             std::function<void ()> Bump) {
 
+    bccAssert(CalleeArgs.empty());
+
+    int Return = -1;
     if (bcinfo::MetadataExtractor::hasForEachSignatureCtxt(Signature)) {
       CalleeArgs.push_back(Arg_p);
       Bump();
+      Return = CalleeArgs.size() - 1;
     }
 
     if (bcinfo::MetadataExtractor::hasForEachSignatureX(Signature)) {
@@ -447,6 +493,8 @@
         Bump();
       }
     }
+
+    return Return;
   }
 
   /* Performs the actual optimization on a selected function. On success, the
@@ -540,8 +588,8 @@
     createLoop(Builder, Arg_x1, Arg_x2, &IV);
 
     llvm::SmallVector<llvm::Value*, 8> CalleeArgs;
-    ExpandSpecialArguments(Signature, IV, Arg_p, Builder, CalleeArgs,
-                           [&FunctionArgIter]() { FunctionArgIter++; });
+    const int CalleeArgsContextIdx = ExpandSpecialArguments(Signature, IV, Arg_p, Builder, CalleeArgs,
+                                                            [&FunctionArgIter]() { FunctionArgIter++; });
 
     bccAssert(FunctionArgIter == Function->arg_end());
 
@@ -585,7 +633,7 @@
       RootArgs.push_back(UsrData);
     }
 
-    RootArgs.append(CalleeArgs.begin(), CalleeArgs.end());
+    finishArgList(RootArgs, CalleeArgs, CalleeArgsContextIdx, *Function, Builder);
 
     Builder.CreateCall(Function, RootArgs);
 
@@ -698,8 +746,8 @@
     createLoop(Builder, Arg_x1, Arg_x2, &IV);
 
     llvm::SmallVector<llvm::Value*, 8> CalleeArgs;
-    ExpandSpecialArguments(Signature, IV, Arg_p, Builder, CalleeArgs,
-                           [&NumInputs]() { --NumInputs; });
+    const int CalleeArgsContextIdx = ExpandSpecialArguments(Signature, IV, Arg_p, Builder, CalleeArgs,
+                                                            [&NumInputs]() { --NumInputs; });
 
     llvm::SmallVector<llvm::Type*,  8> InTypes;
     llvm::SmallVector<llvm::Value*, 8> InSteps;
@@ -814,7 +862,7 @@
       }
     }
 
-    RootArgs.append(CalleeArgs.begin(), CalleeArgs.end());
+    finishArgList(RootArgs, CalleeArgs, CalleeArgsContextIdx, *Function, Builder);
 
     llvm::Value *RetVal = Builder.CreateCall(Function, RootArgs);