Avoid a footgun in KernelArgsArray
Do not store the address of a `const T&`. Storing the address of a `const T&`
means add_argument(42) does not work, which is very counter-intuitive.
PiperOrigin-RevId: 298405422
Change-Id: I769dfa8d7dad92b1e73b1a4f591768b4536cca39
diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h
index d3d386b..514021f 100644
--- a/tensorflow/stream_executor/kernel.h
+++ b/tensorflow/stream_executor/kernel.h
@@ -76,6 +76,7 @@
#include <vector>
#include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
@@ -392,19 +393,22 @@
template <size_t kNumArgs>
class KernelArgsArray : public KernelArgsArrayBase {
public:
- explicit KernelArgsArray()
- : total_shared_memory_bytes_(0),
- number_of_argument_addresses_(0),
- number_of_shared_memory_arguments_(0) {}
+ static constexpr int kMaxGenericArgSize = 8;
// Adds an argument to the list.
- //
- // Note that the address of the argument is stored, so the input must not go
- // out of scope before the instance of this class that calls this method does.
template <typename T>
void add_argument(const T &arg) {
- argument_addresses_[number_of_argument_addresses_] =
- static_cast<const void *>(&arg);
+ static_assert(sizeof(T) <= kMaxGenericArgSize,
+ "Please adjust kMaxGenericArgSize");
+ static_assert(std::is_pod<T>::value, "Only pod types supported!");
+ char *generic_arg_storage =
+ &generic_arguments_[number_of_generic_arguments_++ *
+ kMaxGenericArgSize];
+
+ CHECK_EQ(reinterpret_cast<uintptr_t>(generic_arg_storage) % alignof(T), 0);
+ std::memcpy(generic_arg_storage, &arg, sizeof(T));
+
+ argument_addresses_[number_of_argument_addresses_] = generic_arg_storage;
argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
++number_of_argument_addresses_;
}
@@ -463,6 +467,10 @@
// Addresses for non-shared-memory arguments.
std::array<const void *, kNumArgs> argument_addresses_;
+ // Storage for arguments of templated type.
+ alignas(kMaxGenericArgSize)
+ std::array<char, kNumArgs * kMaxGenericArgSize> generic_arguments_;
+
// Sizes for non-shared-memory arguments.
std::array<size_t, kNumArgs> argument_sizes_;
@@ -473,14 +481,17 @@
std::array<size_t, kNumArgs> shared_memory_indices_;
// Total of all shared memory sizes.
- size_t total_shared_memory_bytes_;
+ size_t total_shared_memory_bytes_ = 0;
// Number of significant entries in argument_addresses_ and argument_sizes_.
- size_t number_of_argument_addresses_;
+ size_t number_of_argument_addresses_ = 0;
// Number of significant entries in shared_memory_bytes_ and
// shared_memory_indices_.
- size_t number_of_shared_memory_arguments_;
+ size_t number_of_shared_memory_arguments_ = 0;
+
+ // The number of generic arguments that have been added to generic_arguments_.
+ size_t number_of_generic_arguments_ = 0;
};
// Typed variant of KernelBase, like a typed device function pointer. See the