Merge "Disallow launching old-style kernels via rsForEach"
diff --git a/slang_backend.cpp b/slang_backend.cpp
index c481a81..cb17cb4 100644
--- a/slang_backend.cpp
+++ b/slang_backend.cpp
@@ -429,17 +429,19 @@
       }
     }
 
-    if (getTargetAPI() >= SLANG_N_TARGET_API) {
+    if (getTargetAPI() >= SLANG_FEATURE_SINGLE_SOURCE_API) {
       if (FD && FD->hasBody() &&
           !Slang::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr)) {
-        bool isKernel = RSExportForEach::isRSForEachFunc(getTargetAPI(), FD);
-        if (isKernel) {
-          // Log kernels by their names, and assign them slot numbers.
+        if (FD->hasAttr<clang::KernelAttr>()) {
+          // Log functions with attribute "kernel" by their names, and assign
+          // them slot numbers. Any other function cannot be used in a
+          // rsForEach() or rsForEachWithOptions() call, including old-style
+          // kernel functions which are defined without the "kernel" attribute.
           mContext->addForEach(FD);
         }
         // Look for any kernel launch calls and translate them into using the
         // internal API.
-        // Report a compiler on kernel launches inside a kernel.
+        // Report a compiler error on kernel launches inside a kernel.
         mForEachHandler.handleForEachCalls(FD, getTargetAPI());
       }
     }
diff --git a/slang_rs_context.cpp b/slang_rs_context.cpp
index 626cc0a..4c7971d 100644
--- a/slang_rs_context.cpp
+++ b/slang_rs_context.cpp
@@ -73,6 +73,7 @@
 
   // Reserve slot 0 for the root kernel.
   mExportForEach.push_back(nullptr);
+  mFirstOldStyleKernel = mExportForEach.end();
 }
 
 bool RSContext::processExportVar(const clang::VarDecl *VD) {
@@ -125,12 +126,25 @@
     if (EFE == nullptr) {
       return false;
     }
-    const llvm::StringRef& funcName = FD->getName();
-    if (funcName.equals("root")) {
+
+    // The root function should be at index 0 in the list
+    if (FD->getName().equals("root")) {
       mExportForEach[0] = EFE;
-    } else {
-      mExportForEach.push_back(EFE);
+      return true;
     }
+
+    // New-style kernels with attribute "kernel" should come first in the list
+    if (FD->hasAttr<clang::KernelAttr>()) {
+      mFirstOldStyleKernel = mExportForEach.insert(mFirstOldStyleKernel, EFE) + 1;
+      slangAssert((mTargetAPI < SLANG_FEATURE_SINGLE_SOURCE_API ||
+                   getForEachSlotNumber(FD->getName()) ==
+                   mFirstOldStyleKernel - mExportForEach.begin() - 1) &&
+                  "Inconsistent slot number assignment");
+      return true;
+    }
+
+    // Old-style kernels go at the end of the list, in reverse order of addition
+    mFirstOldStyleKernel = mExportForEach.insert(mFirstOldStyleKernel, EFE);
     return true;
   }
 
diff --git a/slang_rs_context.h b/slang_rs_context.h
index 473e00f..45edf3a 100644
--- a/slang_rs_context.h
+++ b/slang_rs_context.h
@@ -117,6 +117,7 @@
   ExportFuncList mExportFuncs;
   std::map<llvm::StringRef, unsigned> mExportForEachMap;
   ExportForEachVector mExportForEach;
+  ExportForEachVector::iterator mFirstOldStyleKernel;
   ExportReduceList mExportReduce;
   ExportReduceNewList mExportReduceNew;
   ExportReduceNewResultTypeSet mExportReduceNewResultType;
diff --git a/slang_rs_foreach_lowering.cpp b/slang_rs_foreach_lowering.cpp
index 46e4104..6e9991e 100644
--- a/slang_rs_foreach_lowering.cpp
+++ b/slang_rs_foreach_lowering.cpp
@@ -17,6 +17,7 @@
 #include "slang_rs_foreach_lowering.h"
 
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/Attr.h"
 #include "llvm/Support/raw_ostream.h"
 #include "slang_rs_context.h"
 #include "slang_rs_export_foreach.h"
@@ -129,8 +130,9 @@
   // Verifies that kernel is indeed a "kernel" function.
   *slot = mCtxt->getForEachSlotNumber(kernel);
   if (*slot == -1) {
-    mCtxt->ReportError(CE->getExprLoc(), "%0 applied to non kernel function %1")
-            << funcName << kernel->getName();
+    mCtxt->ReportError(CE->getExprLoc(),
+         "%0 applied to function %1 defined without \"kernel\" attribute")
+         << funcName << kernel->getName();
     return nullptr;
   }
 
@@ -383,7 +385,7 @@
                                            unsigned int targetAPI) {
   slangAssert(FD && FD->hasBody());
 
-  mInsideKernel = RSExportForEach::isRSForEachFunc(targetAPI, FD);
+  mInsideKernel = FD->hasAttr<clang::KernelAttr>();
   VisitStmt(FD->getBody());
 }
 
diff --git a/slang_version.h b/slang_version.h
index 1580f93..74cd09e 100644
--- a/slang_version.h
+++ b/slang_version.h
@@ -53,7 +53,8 @@
 //     SLANG_FEAT_BAR_API_MIN, SLANG_FEAT_BAR_API_MAX
 enum SlangFeatureAPI {
   SLANG_FEATURE_GENERAL_REDUCTION_API = SLANG_N_TARGET_API,
-  SLANG_FEATURE_GENERAL_REDUCTION_HALTER_API = SLANG_DEVELOPMENT_TARGET_API
+  SLANG_FEATURE_GENERAL_REDUCTION_HALTER_API = SLANG_DEVELOPMENT_TARGET_API,
+  SLANG_FEATURE_SINGLE_SOURCE_API = SLANG_N_TARGET_API,
 };
 
 // SlangVersion refers to the released compiler version (for which certain
diff --git a/tests/F_foreach_non_kernel/foreach_non_kernel.rs b/tests/F_foreach_non_kernel/foreach_non_kernel.rs
index 6800e64..dafce3c 100644
--- a/tests/F_foreach_non_kernel/foreach_non_kernel.rs
+++ b/tests/F_foreach_non_kernel/foreach_non_kernel.rs
@@ -2,11 +2,16 @@
 #pragma version(1)
 #pragma rs java_package_name(com.example.foo)
 
+void oldFoo(const int* a, int *b) {
+  *b = *a;
+}
+
 int foo(int a) {
   return a;
 }
 
 void testStart(rs_allocation in, rs_allocation out) {
+  rsForEach(oldFoo, in, out);
   rsForEach(foo, in, out);
 }
 
diff --git a/tests/F_foreach_non_kernel/stderr.txt.expect b/tests/F_foreach_non_kernel/stderr.txt.expect
index 6faada1..5ec8b50 100644
--- a/tests/F_foreach_non_kernel/stderr.txt.expect
+++ b/tests/F_foreach_non_kernel/stderr.txt.expect
@@ -1 +1,2 @@
-foreach_non_kernel.rs:10:3: error: rsForEach applied to non kernel function foo
+foreach_non_kernel.rs:14:3: error: rsForEach applied to function oldFoo defined without "kernel" attribute
+foreach_non_kernel.rs:15:3: error: rsForEach applied to function foo defined without "kernel" attribute
diff --git a/tests/P_foreach/foreach.rs b/tests/P_foreach/foreach.rs
index a87aab9..c23fb81 100644
--- a/tests/P_foreach/foreach.rs
+++ b/tests/P_foreach/foreach.rs
@@ -2,6 +2,10 @@
 #pragma version(1)
 #pragma rs java_package_name(com.example.foo)
 
+void oldFoo(const int* a, int *b) {
+  *b = *a;
+}
+
 int RS_KERNEL foo(int a) {
   return a;
 }