/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Suite of datatypes to represent data-parallel kernel objects (code entities).
// KernelBase is the untyped variant, whereas TypedKernel takes a type signature
// to do some template-based helper generation and give compile-time type
// checking for kernel launch parameters.
//
// Users typically don't see KernelBase, they see typed kernels, analogous to a
// typed function pointer. TypedKernels express their argument types via
// template parameters like so:
//
//  TypedKernel<DeviceMemory<int>*, int>
//
// Which expresses a data parallel kernel signature for:
//
//  void(int*, int);
//
// And for a const memory region:
//
//  TypedKernel<const DeviceMemory<int>&, int>
//
// Corresponds to a data parallel kernel signature for:
//
//  void(const int*, int)
//
// Note that kernels always have a void return type, so results typically must
// be memcpy'ied from device memory to the host.
//
// Also note that a scalar integer residing in device memory and an array of
// integers residing in device memory have the same signature: DeviceMemory<T>.
// However, in the future, checks may be added for additional safety that arrays
// of minimum sizes are passed when those minimum sizes are contractually
// expected by the kernel.
//
// For user-defined types whose definitions are appropriately shared between the
// host code doing the launching and the kernel code being launched, the user
// defined types are similarly permitted to be expressed as residing in device
// memory:
//
//  TypedKernel<DeviceMemory<MyUserDefinedStructure>>
//
// And, when the alignment and padding are agreed upon, POD types will also be
// able to be passed by value; for example, it is a common idiom to specify a
// bunch of options simultaneously with a structure:
//
//  TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
//
// Which corresponds to a data parallel kernel signature like:
//
//  void(MyOptionsStructurePassedByValue value, float *result);
//
// Users typically won't need to type out the TypedKernel signature in full, it
// will be typedef'd by automatically generated code; for example, see
// stream_executor::executor_sample::VecReduceAddKernel.

#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_

#include <array>
#include <memory>
#include <tuple>
#include <type_traits>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace stream_executor {

class DeviceMemoryBase;
template <typename ElemT>
class DeviceMemory;
class StreamExecutor;

namespace internal {
class KernelInterface;
}  // namespace internal

// KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as
// registers allocated, shared memory used, etc.
// Not all platforms support reporting of all information, so each accessor
// returns false if the associated field is not populated in the underlying
// platform.
class KernelMetadata {
 public:
  KernelMetadata()
      : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {}

  // Returns the number of registers used per thread executing this kernel.
  bool registers_per_thread(int *registers_per_thread) const;

  // Sets the number of registers used per thread executing this kernel.
  void set_registers_per_thread(int registers_per_thread);

  // Returns the amount of [static] shared memory used per block executing this
  // kernel. Note that dynamic shared memory allocations are not (and can not)
  // be reported here (since they're not specified until kernel launch time).
  bool shared_memory_bytes(int *shared_memory_bytes) const;

  // Sets the amount of [static] shared memory used per block executing this
  // kernel.
  void set_shared_memory_bytes(int shared_memory_bytes);

 private:
  // Holds the value returned by registers_per_thread above.
  bool has_registers_per_thread_;
  int registers_per_thread_;

  // Holds the value returned by shared_memory_bytes above.
  bool has_shared_memory_bytes_;
  int64 shared_memory_bytes_;
};

// A data-parallel kernel (code entity) for launching via the StreamExecutor,
// analogous to a void* device function pointer. See TypedKernel for the typed
// variant.
//
// A move constructor is declared; copy and assignment are disabled through
// SE_DISALLOW_COPY_AND_ASSIGN at the bottom of the class.
//
// Thread-compatible.
class KernelBase {
 public:
  // Move constructor. Defined out of line.
  KernelBase(KernelBase &&from);

  // Constructs an "empty" (not-yet-loaded) kernel instance.
  //
  // parent is the StreamExecutor that will be responsible for loading the
  // implementation of this kernel. It must not be null.
  explicit KernelBase(StreamExecutor *parent);

  // Test-only constructor that can take a mock KernelInterface implementation.
  // NOTE(review): the implementation is held in a std::unique_ptr member, so
  // this presumably takes ownership of the pointer -- confirm in the
  // definition.
  KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation);

  // Releases resources associated with the kernel instance (i.e.
  // platform-specific implementation).
  ~KernelBase();

  // Returns the number of parameters that this kernel accepts. (Arity refers to
  // nullary, unary, ...).
  unsigned Arity() const;

  // Returns the StreamExecutor that represents the platform this kernel
  // executes upon.
  StreamExecutor *parent() const { return parent_; }

  // Returns a const pointer to the (opaque) platform-dependent implementation.
  const internal::KernelInterface *implementation() const {
    return implementation_.get();
  }

  // Returns a non-const pointer to the (opaque) platform-dependent
  // implementation.
  internal::KernelInterface *implementation() { return implementation_.get(); }

  // Replaces the stored metadata with a copy of the given metadata.
  void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; }

  // Returns the metadata most recently stored via set_metadata (or the
  // default-constructed value if none was stored).
  const KernelMetadata &metadata() const { return metadata_; }

  // Sets the preferred cache configuration for a kernel. This is just a
  // suggestion to the runtime, and may not be honored during execution.
  void SetPreferredCacheConfig(KernelCacheConfig config);

  // Gets the preferred cache configuration for a kernel.
  KernelCacheConfig GetPreferredCacheConfig() const;

  // Sets the kernel's name. Defined out of line.
  // NOTE(review): presumably this also derives demangled_name_ -- confirm in
  // the definition.
  void set_name(absl::string_view name);

  // Returns the stored kernel name.
  const string &name() const { return name_; }

  // Returns the stored demangled kernel name.
  const string &demangled_name() const { return demangled_name_; }

 private:
  // The StreamExecutor that loads this kernel object.
  StreamExecutor *parent_;

  // Implementation delegated to for platform-specific functionality.
  std::unique_ptr<internal::KernelInterface> implementation_;

  // Values returned by name() and demangled_name() respectively.
  string name_;
  string demangled_name_;

  // Value returned by metadata().
  KernelMetadata metadata_;

  SE_DISALLOW_COPY_AND_ASSIGN(KernelBase);
};

// Whether T is a DeviceMemory-family pointer.
//
// The trait derives from std::true_type / std::false_type rather than
// declaring its own `static constexpr bool value`. A bare in-class constexpr
// static member has no out-of-class definition before C++17, so ODR-using it
// (e.g. binding it to a const reference) fails at link time; the standard
// library's integral_constant supplies that definition.
template <typename T>
struct IsDeviceMemoryPointer : std::false_type {};

// Pointer to a typed device memory region.
template <typename U>
struct IsDeviceMemoryPointer<DeviceMemory<U> *> : std::true_type {};

// Pointer to an untyped device memory region.
template <>
struct IsDeviceMemoryPointer<DeviceMemoryBase *> : std::true_type {};

// Whether T is a DeviceMemory-family value-like thing (which includes a
// reference). This trait is useful because we pack values in the same manner as
// references.
//
// The trait derives from std::true_type / std::false_type rather than
// declaring its own `static constexpr bool value`. A bare in-class constexpr
// static member has no out-of-class definition before C++17, so ODR-using it
// fails at link time; the standard library's integral_constant supplies that
// definition.
template <typename T>
struct IsDeviceMemoryValueLike : std::false_type {};

// Reference forms.
template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U> &> : std::true_type {};

// We need to treat SharedDeviceMemory types differently than other DeviceMemory
// types (since they maintain no allocations), hence these specializations.
template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> : std::false_type {};

template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase &> : std::true_type {};

// Value forms.
template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U>> : std::true_type {};

// Shared memory is explicitly excluded in its value form as well.
template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> : std::false_type {};

template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase> : std::true_type {};

// Whether U is a SharedDeviceMemory, either by value or by (non-const)
// reference.
//
// The trait derives from std::true_type / std::false_type rather than
// declaring its own `static constexpr bool value`. A bare in-class constexpr
// static member has no out-of-class definition before C++17, so ODR-using it
// fails at link time; the standard library's integral_constant supplies that
// definition.
template <typename U>
struct IsSharedDeviceMemory : std::false_type {};

template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> : std::true_type {};

template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U>> : std::true_type {};

// Basic data about a kernel argument, as produced by KernelArgIterator::next.
struct KernelArg {
  // True when the argument is a shared-memory argument.
  bool is_shared;
  // Address of the argument's storage; nullptr for shared-memory arguments.
  const void *address;
  // Size of the argument in bytes; for shared-memory arguments this is the
  // number of shared bytes requested.
  size_t size;
};

// An iterator for traversing all the arguments of a KernelArgsArray.
//
// The arguments are supplied as two separate sequences: one of addresses/sizes
// for the non-shared arguments, and one of byte counts for the shared-memory
// arguments together with the position each of them occupies in the combined
// argument list. The iterator merges the two back into the original order.
//
// The iterator does not own any of the arrays it is given; they must outlive
// it.
class KernelArgIterator {
 public:
  // number_of_argument_addresses: entries in arg_addresses_data and
  //     arg_sizes_data.
  // number_of_shared_memory_arguments: entries in shmem_bytes_data and
  //     shmem_indices_data.
  // shmem_indices_data: positions, in the combined argument list, at which the
  //     shared-memory arguments appear. next() only compares against the
  //     current head of this list, so it is expected to be in increasing order.
  //
  // Both counts are converted to size_t when stored and must be non-negative.
  KernelArgIterator(int number_of_argument_addresses,
                    int number_of_shared_memory_arguments,
                    const void *const *arg_addresses_data,
                    const size_t *arg_sizes_data,
                    const size_t *shmem_bytes_data,
                    const size_t *shmem_indices_data)
      : arg_index_(0),
        number_of_arguments_(number_of_argument_addresses +
                             number_of_shared_memory_arguments),
        arg_address_iter_(arg_addresses_data),
        arg_size_iter_(arg_sizes_data),
        shmem_bytes_iter_(shmem_bytes_data),
        shmem_indices_iter_(shmem_indices_data),
        shmem_indices_end_(shmem_indices_data +
                           number_of_shared_memory_arguments) {}

  // Returns true if another argument is present in the iterator.
  //
  // This is a pure query, so it is const-qualified and may be called on a
  // const iterator.
  bool has_next() const { return arg_index_ < number_of_arguments_; }

  // Returns the next argument in the iterator and advances past it.
  //
  // Returns a default-constructed (all-zero) KernelArg if there is no next
  // argument; the iterator is left unchanged in that case.
  KernelArg next() {
    KernelArg result = {};
    if (!has_next()) {
      return result;
    }

    // The next argument is a shared-memory one exactly when the current
    // position equals the next not-yet-consumed shared-memory index.
    const bool next_is_shared = (shmem_indices_iter_ != shmem_indices_end_) &&
                                (arg_index_ == *shmem_indices_iter_);
    if (next_is_shared) {
      result.is_shared = true;
      result.address = nullptr;
      result.size = *shmem_bytes_iter_;
      ++shmem_indices_iter_;
      ++shmem_bytes_iter_;
    } else {
      result.is_shared = false;
      result.address = *arg_address_iter_;
      result.size = *arg_size_iter_;
      ++arg_address_iter_;
      ++arg_size_iter_;
    }
    ++arg_index_;
    return result;
  }

 private:
  // Position of the next argument in the combined argument list.
  size_t arg_index_;
  // Total number of arguments, shared and non-shared.
  size_t number_of_arguments_;
  // Cursors into the non-shared argument addresses and sizes.
  const void *const *arg_address_iter_;
  const size_t *arg_size_iter_;
  // Cursors into the shared-memory byte counts and positions.
  const size_t *shmem_bytes_iter_;
  const size_t *shmem_indices_iter_;
  // One past the last shared-memory position.
  const size_t *const shmem_indices_end_;
};

// Base class for KernelArgsArray.
//
// Supports all the getter methods that do not depend on the compile-time number
// of arguments template parameter.
//
// This class exists as a way to pass kernel arguments to
// StreamExecutorInterface::Launch. That Launch method is virtual, so it can't
// be templated to accept any KernelArgsArray type, therefore a reference to
// this base type is passed instead.
//
// Performance is not a concern here because each of these methods will be
// called at most once per kernel launch. Past performance concerns with
// KernelArgsArray have been in reference to the argument packing routines which
// are called once per kernel argument. Those packing routines are now handled
// by the templated KernelArgsArray subclass of this class where they can take
// advantage of compile-time knowledge of the number of arguments in order to be
// very efficient.
class KernelArgsArrayBase {
 public:
  // Virtual so that a concrete KernelArgsArray can be destroyed through a
  // pointer to this base.
  virtual ~KernelArgsArrayBase() = default;

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  virtual size_t number_of_arguments() const = 0;

  // Gets the total number of shared memory bytes added so far.
  virtual uint64 number_of_shared_bytes() const = 0;

  // Gets the list of argument addresses. In KernelArgsArray this list holds
  // only the non-shared-memory arguments.
  virtual port::ArraySlice<const void *> argument_addresses() const = 0;

  // Gets an iterator to the arguments in the array. The iterator yields both
  // shared-memory and non-shared-memory arguments as KernelArg values.
  virtual KernelArgIterator arg_iterator() const = 0;
};

// A fixed-capacity list of arguments for a single kernel call.
//
// kNumArgs is the maximum number of arguments the list can hold. None of the
// add_* methods checks against that capacity; callers are responsible for not
// exceeding it.
//
// Two parallel records are kept: the addresses (and sizes) of the
// non-shared-memory arguments, and the byte counts of the shared-memory
// arguments. Because the two kinds may be interleaved in the original argument
// list, the position of every shared-memory argument is recorded as well.
//
// For example, if the argument address list contains {a, b, c, d, e}, the
// shared-memory arguments list contains the sizes of {A, B, C}, and the
// shared-memory indices list contains {0, 3, 5}, then the original list of
// arguments was {A, a, b, B, c, C, d, e}.
//
// This layout keeps CUDA launches cheap (they need only the address list and
// the total shared byte count) while still letting OpenCL launches recover the
// position and size of each shared-memory argument.
//
// Adding arguments has been identified as a performance hotspot in real-world
// applications, so the add_* methods are deliberately minimal.
template <size_t kNumArgs>
class KernelArgsArray : public KernelArgsArrayBase {
 public:
  // Creates an empty argument list.
  explicit KernelArgsArray()
      : total_shared_memory_bytes_(0),
        number_of_argument_addresses_(0),
        number_of_shared_memory_arguments_(0) {}

  // Appends a by-value argument.
  //
  // Only the address of arg is recorded, so arg must stay alive for as long as
  // this list is in use.
  template <typename T>
  void add_argument(const T &arg) {
    const size_t slot = number_of_argument_addresses_++;
    argument_addresses_[slot] = static_cast<const void *>(&arg);
    argument_sizes_[slot] = sizeof(arg);
  }

  // Appends a device memory argument.
  //
  // The opaque device pointer is copied into storage owned by this list and
  // the address of that copy is what gets recorded as the argument address.
  void add_device_memory_argument(const DeviceMemoryBase &arg) {
    const size_t slot = number_of_argument_addresses_++;
    const void **opaque_copy = &device_memory_opaque_pointers_[slot];
    *opaque_copy = arg.opaque();
    argument_addresses_[slot] = opaque_copy;
    argument_sizes_[slot] = sizeof(void *);
  }

  // Appends a shared memory argument.
  //
  // A shared argument is fully described by its size, so that is all this
  // method takes. Its position in the overall argument list is recorded too.
  void add_shared_bytes(size_t number_of_bytes) {
    const size_t slot = number_of_shared_memory_arguments_++;
    // Position among all arguments added so far, shared or not.
    shared_memory_indices_[slot] = number_of_argument_addresses_ + slot;
    shared_memory_bytes_[slot] = number_of_bytes;
    total_shared_memory_bytes_ += number_of_bytes;
  }

  // Number of arguments added so far, counting shared memory arguments.
  size_t number_of_arguments() const override {
    return number_of_shared_memory_arguments_ + number_of_argument_addresses_;
  }

  // Sum of the sizes of all shared memory arguments added so far.
  uint64 number_of_shared_bytes() const override {
    return total_shared_memory_bytes_;
  }

  // View over the addresses of the non-shared-memory arguments added so far.
  port::ArraySlice<const void *> argument_addresses() const override {
    const void *const *const first = argument_addresses_.data();
    return port::ArraySlice<const void *>(first,
                                          number_of_argument_addresses_);
  }

  // Iterator over every argument added so far, in original order.
  KernelArgIterator arg_iterator() const override {
    return KernelArgIterator(
        number_of_argument_addresses_, number_of_shared_memory_arguments_,
        argument_addresses_.data(), argument_sizes_.data(),
        shared_memory_bytes_.data(), shared_memory_indices_.data());
  }

 private:
  // Copies of the opaque pointers of device memory arguments; entries of
  // argument_addresses_ may point into this array.
  std::array<const void *, kNumArgs> device_memory_opaque_pointers_;

  // Address of each non-shared-memory argument.
  std::array<const void *, kNumArgs> argument_addresses_;

  // Size of each non-shared-memory argument.
  std::array<size_t, kNumArgs> argument_sizes_;

  // Byte count of each shared memory argument.
  std::array<size_t, kNumArgs> shared_memory_bytes_;

  // Position of each shared memory argument in the overall argument list.
  std::array<size_t, kNumArgs> shared_memory_indices_;

  // Running total of shared_memory_bytes_.
  size_t total_shared_memory_bytes_;

  // Count of valid entries in argument_addresses_ and argument_sizes_.
  size_t number_of_argument_addresses_;

  // Count of valid entries in shared_memory_bytes_ and
  // shared_memory_indices_.
  size_t number_of_shared_memory_arguments_;
};

// Typed variant of KernelBase, like a typed device function pointer. See the
// file comment for details and example usage.
//
// This class contains template metaprogramming magic to type check the
// parameters passed to a kernel launch are acceptable, and subsequently pack
// them into a form which can be used by the StreamExecutorInterface
// implementation. (i.e.  CUDA and OpenCL both bind void*s with associated
// sizes as kernel arguments.)
//
// Thread-compatible.
template <typename... Params>
class TypedKernel : public KernelBase {
 public:
  static constexpr size_t kNumberOfParameters = sizeof...(Params);

  // Delegates to KernelBase::KernelBase(), see that constructor.
  explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {}

  // Test-only constructor that can take a mock KernelInterface implementation.
  // Takes ownership of implementation, it should not be null.
  TypedKernel(StreamExecutor *parent, internal::KernelInterface *implementation)
      : KernelBase(parent, implementation) {}

 private:
  // Stream needs access to the specific parameter-packing functionality that
  // the TypedKernel provides for its corresponding type signature (and no other
  // type signatures).
  friend class Stream;

  // This is the main entry point into the magic. Packs the parameters (which
  // must type check against the class template) into the args and sizes
  // arrays.
  //
  // Const refs are taken as parameters on all of the handlers to avoid
  // implicit type promotion of integers.
  //
  // WARNING: as a performance optimization this method may store pointers to
  // some of the input parameters in the kernel args structure, so any params
  // passed into this method must live at least as long as the kernel args
  // structure.
  void PackParams(KernelArgsArray<kNumberOfParameters> *args,
                  Params &... params) const {
    PackOneParamFromList(args, params...);
  }

  template <typename T, typename... RestOfParams>
  void PackOneParamFromList(KernelArgsArray<kNumberOfParameters> *args,
                            const T &arg, const RestOfParams &... rest) const {
    PackOneParam(args, arg);
    PackOneParamFromList(args, rest...);
  }

  // Base case for variadic template expansion - nothing to do!
  void PackOneParamFromList(KernelArgsArray<kNumberOfParameters> *args) const {}

  // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array.
  // The enable_if<> is for excluding DeviceMemoryBase args, which have a
  // separate implementation below.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<!IsDeviceMemoryValueLike<T>::value &&
                              !IsDeviceMemoryPointer<T>::value &&
                              !IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    static_assert(!std::is_pointer<T>::value,
                  "cannot pass raw pointer to the device");
    static_assert(!std::is_convertible<T, DeviceMemoryBase>::value,
                  "cannot pass device memory as a normal value");
    args->add_argument(arg);
  }

  // DeviceMemoryBase family reference override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * =
          nullptr) const {
    args->add_device_memory_argument(arg);
  }

  // DeviceMemoryBase family pointer override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * =
          nullptr) const {
    DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg);
    args->add_device_memory_argument(*ptr);
  }

  // Dynamic shared device memory has a size, but no associated allocation on
  // the host; internally, the device will allocate storage.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    args->add_shared_bytes(arg.size());
  }

  SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
};

// Template metaprogramming helper type that helps us produce better error
// messages at compile time when there are mismatches between the parameter
// type list and the argument type list.
//
// ParamTuple is a std::tuple of the kernel's declared parameter types and
// ArgTuple is a std::tuple of the types of the arguments supplied at the call
// site.
template <typename ParamTuple, typename ArgTuple>
struct KernelInvocationChecker {
  // Whether the parameter tuple and argument tuple match in length.
  static constexpr bool kLengthMatches =
      std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value;

  // The (matching) length of the parameters and arguments type lists.
  static constexpr int kTupleLength =
      static_cast<int>(std::tuple_size<ArgTuple>::value);

  // Number of positions that are compared element-wise. When the lengths
  // differ nothing is compared: indexing ParamTuple with an index taken from a
  // longer ArgTuple would instantiate std::tuple_element out of range, which is
  // a hard compile error rather than a `false` result (or, for
  // CheckAllStaticAssert, rather than the intended length-mismatch message).
  static constexpr int kNumCheckedParams = kLengthMatches ? kTupleLength : 0;

  // Helper trait to say whether the parameter wants a DeviceMemory-reference
  // compatible type. This is for inexact type matches, so that it doesn't have
  // to be precisely a const DeviceMemory<T>&, but can also be a value that
  // represents the same.
  template <typename ParamType, typename ArgType>
  struct IsCompatibleDeviceMemoryRef {
    static constexpr bool value = false;
  };

  // A DeviceMemory<U> value satisfies a const DeviceMemory<U>& parameter.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // A SharedDeviceMemory<U> value satisfies a const SharedDeviceMemory<U>&
  // parameter.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &,
                                     SharedDeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // Returns whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing without any assert functionality. They are compatible
  // when they are the same type once a top-level const is removed from ParamT,
  // or when IsCompatibleDeviceMemoryRef accepts the pair.
  template <typename ParamT, typename ArgT>
  static constexpr bool CompatibleNoAssert() {
    return std::is_same<typename std::remove_const<ParamT>::type,
                        ArgT>::value ||
           IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value;
  }

  // Checks whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing. kArgumentNumber is unused; it is just for error display.
  //
  // NOTE: if you encounter an error here, you can see the mismatch by looking
  // at the end of the last error message, which will be of the form:
  //
  //    ...::Compatible<const stream_executor::DeviceMemory<OneThing> &,
  //                    stream_executor::DeviceMemory<AnotherThing>, true,
  //                    0>'
  //    requested here
  //
  // This means that the 0th argument you passed to the kernel invocation should
  // have been DeviceMemory<OneThing> but was observed to be
  // DeviceMemory<AnotherThing>.
  template <typename ParamT, typename ArgT, bool kShouldStaticAssert,
            int kArgumentNumber>
  static constexpr bool Compatible() {
    static_assert(
        kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true,
        "parameter type (LHS) is not compatible with argument type (RHS)");
    return CompatibleNoAssert<ParamT, ArgT>();
  }

  // Checks the parameter/argument match at kArgumentNumber for an out of bounds
  // argument number.
  //
  // This is the base case: we've run out of arguments to check, so we're all
  // good.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) {
    return true;
  }

  // Checks the parameter/argument match at kArgumentNumber and, recursively,
  // at every lower index. kShouldStaticAssert determines whether to assert out
  // on a mismatch, or just yield the constexpr boolean value.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) {
    typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type
        ParamT;
    typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT;
    return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() &&
           CheckParam<kArgumentNumber - 1, kShouldStaticAssert>();
  }

  // Checks the parameters/arguments for match, but doesn't static assert out.
  // This is useful for testing/inspecting whether a set of parameters match in
  // things like tests. Yields false (rather than failing to compile) when the
  // two lists differ in length.
  static constexpr bool CheckAllNoStaticAssert() {
    return kLengthMatches && CheckParam<kNumCheckedParams - 1, false>();
  }

  // Checks the parameters and static asserts out with a helpful error message
  // (and useful template parameters in the instantiation stack) if there is an
  // error.
  static constexpr bool CheckAllStaticAssert() {
    static_assert(kLengthMatches,
                  "argument length mismatched against typed kernel parameters");
    return kLengthMatches && CheckParam<kNumCheckedParams - 1, true>();
  }
};

// Convenience trait answering whether a typed kernel can be invoked with a
// given list of argument types. The primary template covers any KernelT that is
// not a TypedKernel and always reports false.
template <typename KernelT, typename... Params>
struct KernelParamsOk {
  static constexpr bool kResult{false};
};

// For an actual TypedKernel, the kernel's parameter types are checked against
// the supplied argument types without triggering any static_assert.
template <typename... Params, typename... Args>
struct KernelParamsOk<TypedKernel<Params...>, Args...> {
 private:
  typedef KernelInvocationChecker<std::tuple<Params...>, std::tuple<Args...>>
      Checker;

 public:
  static constexpr bool kResult = Checker::CheckAllNoStaticAssert();
};

}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
