Improve boxed dispatch performance (#33313)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33313
Instead of just remembering the number of arguments and iterating over the whole stack,
the DispatchKeyExtractor now remembers the exact locations of the dispatch-relevant
arguments (i.e. Tensor and Tensor[] arguments) and only looks at those.
ghstack-source-id: 101908386
Test Plan: unit tests, benchmarks
Differential Revision: D19748549
fbshipit-source-id: b5b9ff2233b3507e0b600460f422912cfa9e3f0f
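
The core idea, as a minimal self-contained sketch (the IValue/Stack/peek types
below are toy stand-ins for the real torch::jit and c10 ones, not the actual
dispatcher code): precompute a bitset whose bit i says that the i-th argument
from the top of the stack is dispatch-relevant, then visit only those slots
instead of scanning all of the last num_args_ entries:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Toy stand-ins for torch::jit::Stack / c10::IValue, for illustration only.
    struct IValue { bool is_tensor; uint64_t key_set; };
    using Stack = std::vector<IValue>;

    // The real code uses torch::jit::peek(stack, 0, n); this toy version reads
    // the IValue n slots from the top of the stack.
    const IValue& peek(const Stack& s, size_t n) { return s[s.size() - n]; }

    // Visit only the stack slots whose bit is set and union their key sets.
    // Uses the GCC/Clang __builtin_ctzll; the patch itself uses
    // ffsll/_BitScanForward64 behind c10::utils::bitset.
    uint64_t compute_key_set(const Stack& stack, uint64_t dispatch_arg_indices_reverse) {
      uint64_t ks = 0;
      for (uint64_t bits = dispatch_arg_indices_reverse; bits != 0; bits &= bits - 1) {
        size_t reverse_index = __builtin_ctzll(bits);  // zero-based index of lowest set bit
        const IValue& iv = peek(stack, reverse_index + 1);
        if (iv.is_tensor) ks |= iv.key_set;
      }
      return ks;
    }

    int main() {
      // A schema like "foo(Tensor a, int b, Tensor c)" has tensors at
      // reverse indices 2 (a) and 0 (c).
      Stack stack = {{true, 0b01}, {false, 0}, {true, 0b10}};
      uint64_t mask = (1u << 2) | (1u << 0);
      std::printf("key set: 0x%llx\n", (unsigned long long)compute_key_set(stack, mask));
      return 0;
    }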
diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp
index 49849f0..adde912 100644
--- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp
+++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp
@@ -14,12 +14,19 @@
std::string DispatchKeyExtractor::dumpState() const {
std::ostringstream oss;
- oss << num_args_ << " " << operatorHasKernelForBackend_ << "\n";
+ for (size_t i = 0; i < c10::utils::bitset::NUM_BITS(); ++i) {
+ if (dispatch_arg_indices_reverse_.get(i)) {
+ oss << "1";
+ } else {
+ oss << "0";
+ }
+ }
+ oss << " " << operatorHasKernelForBackend_ << "\n";
return oss.str();
}
void DispatchKeyExtractor::checkInvariants(const FunctionSchema& schema) const {
- TORCH_INTERNAL_ASSERT(schema.arguments().size() == num_args_);
+ TORCH_INTERNAL_ASSERT(makeBitsetForDispatchArgs(schema) == dispatch_arg_indices_reverse_);
}
} // namespace c10
diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
index 6644bb0..ee911b7 100644
--- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
+++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
@@ -3,6 +3,7 @@
#include <cstdint>
#include <ATen/core/function_schema.h>
#include <ATen/core/jit_type.h>
+#include <c10/util/Bitset.h>
#include <c10/core/DispatchKeySet.h>
#include <ATen/core/Variadic.h>
#include <ATen/core/stack.h>
@@ -101,19 +102,19 @@
struct CAFFE2_API DispatchKeyExtractor final {
public:
static DispatchKeyExtractor make(const FunctionSchema& schema) {
- return DispatchKeyExtractor(schema.arguments().size());
+ return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema));
}
static DispatchKeyExtractor makeUninitialized() {
- return DispatchKeyExtractor(0);
+ return DispatchKeyExtractor(c10::utils::bitset());
}
void registerSchema(const FunctionSchema& schema) {
- TORCH_INTERNAL_ASSERT(num_args_ == 0);
- num_args_ = schema.arguments().size();
+ TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset());
+ dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema);
}
void deregisterSchema() {
- num_args_ = 0;
+ dispatch_arg_indices_reverse_ = c10::utils::bitset();
}
DispatchKey getDispatchKeyBoxed(DispatchKeySet backendsWithoutFallthrough, const torch::jit::Stack* stack) const {
@@ -121,7 +122,8 @@
// but boxed doesn't yet. See https://github.com/pytorch/pytorch/issues/26428
DispatchKeySet ks;
- for (const auto& ivalue : torch::jit::last(*stack, num_args_)) {
+ dispatch_arg_indices_reverse_.for_each_set_bit([&] (size_t reverse_arg_index) {
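+ // peek(stack, 0, n) reads the IValue n slots from the top of the stack,
+ // so reverse_arg_index == 0 refers to the function's last argument.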
+ const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
if (C10_LIKELY(ivalue.isTensor())) {
// NB: Take care not to introduce a refcount bump (there's
// no safe toTensorRef method, alas)
@@ -131,7 +133,7 @@
ks = ks | tensor.key_set();
}
}
- }
+ });
return dispatchKeySetToDispatchKey_(backendsWithoutFallthrough, DispatchKeySet::FULL, ks);
}
@@ -149,6 +151,19 @@
void checkInvariants(const FunctionSchema& schema) const;
private:
+ static c10::utils::bitset makeBitsetForDispatchArgs(const FunctionSchema& schema) {
+ TORCH_CHECK(schema.arguments().size() <= c10::utils::bitset::NUM_BITS(),
+ "The function schema has ", schema.arguments().size(),
+ " arguments but this PyTorch build only supports ", c10::utils::bitset::NUM_BITS());
+ c10::utils::bitset dispatch_arg_indices_reverse;
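+ // Illustration: for a schema like "foo(Tensor a, int b, Tensor c)", the
+ // loop below sets bits 2 (for a) and 0 (for c), i.e. positions counted
+ // from the last argument.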
+ for (size_t index = 0; index < schema.arguments().size(); ++index) {
+ if (schema.arguments()[index].type()->isSubtypeOf(TensorType::get()) || schema.arguments()[index].type()->isSubtypeOf(ListType::ofTensors())) {
+ dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
+ }
+ }
+ return dispatch_arg_indices_reverse;
+ }
+
// NB: If there is no valid dispatch key, this will return Undefined
DispatchKey dispatchKeySetToDispatchKey_(
DispatchKeySet backendsWithoutFallthrough,
@@ -176,16 +191,19 @@
& eligibleKeys);
}
- explicit DispatchKeyExtractor(size_t num_args)
- : num_args_(num_args)
+ explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
+ : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse)
, operatorHasKernelForBackend_() {}
- // this is caching the index so we don't have to parse the schema inputs
- // again and again for each dispatcher lookup.
- // num_args_ is allowed to be zero; that just means you must do the
+ // this is a bitset that has a bit set for each argument index that has to be
+ // considered for dispatch. This avoids having to iterate over the stack
+ // to find all the tensors. The bits are stored in reverse order, i.e. if
+ // dispatch_arg_indices_reverse_[i] == true, then the i-th argument from
+ // the top of the stack (i.e. the i-th last argument of the function)
+ // is relevant for dispatch.
+ // dispatch_arg_indices_reverse_ is allowed to have zero bits set; that
+ // just means you must do the
// fallthrough
- // TODO: a potential optimization is to store a bitfield of arg locations,
- size_t num_args_;
+ c10::utils::bitset dispatch_arg_indices_reverse_;
// Set of backends for which the operator has explicitly registered a kernel.
DispatchKeySet operatorHasKernelForBackend_;
diff --git a/c10/test/util/Bitset_test.cpp b/c10/test/util/Bitset_test.cpp
new file mode 100644
index 0000000..12d5fcf
--- /dev/null
+++ b/c10/test/util/Bitset_test.cpp
@@ -0,0 +1,141 @@
+#include <gtest/gtest.h>
+
+#include <c10/util/Bitset.h>
+
+using c10::utils::bitset;
+
+TEST(BitsetTest, givenEmptyBitset_whenGettingBit_thenIsZero) {
+ bitset b;
+ for (size_t i = 0; i < bitset::NUM_BITS(); ++i) {
+ EXPECT_FALSE(b.get(i));
+ }
+}
+
+TEST(BitsetTest, givenEmptyBitset_whenUnsettingBit_thenIsZero) {
+ bitset b;
+ b.unset(4);
+ for (size_t i = 0; i < bitset::NUM_BITS(); ++i) {
+ EXPECT_FALSE(b.get(i));
+ }
+}
+
+TEST(BitsetTest, givenEmptyBitset_whenSettingAndUnsettingBit_thenIsZero) {
+ bitset b;
+ b.set(4);
+ b.unset(4);
+ for (size_t i = 0; i < bitset::NUM_BITS(); ++i) {
+ EXPECT_FALSE(b.get(i));
+ }
+}
+
+TEST(BitsetTest, givenEmptyBitset_whenSettingBit_thenIsSet) {
+ bitset b;
+ b.set(6);
+ EXPECT_TRUE(b.get(6));
+}
+
+TEST(BitsetTest, givenEmptyBitset_whenSettingBit_thenOthersStayUnset) {
+ bitset b;
+ b.set(6);
+ for (size_t i = 0; i < 6; ++i) {
+ EXPECT_FALSE(b.get(i));
+ }
+ for (size_t i = 7; i < bitset::NUM_BITS(); ++i) {
+ EXPECT_FALSE(b.get(i));
+ }
+}
+
+TEST(BitsetTest, givenNonemptyBitset_whenSettingBit_thenIsSet) {
+ bitset b;
+ b.set(6);
+ b.set(30);
+ EXPECT_TRUE(b.get(30));
+}
+
+TEST(BitsetTest, givenNonemptyBitset_whenSettingBit_thenOthersStayAtOldValue) {
+ bitset b;
+ b.set(6);
+ b.set(30);
+ for (size_t i = 0; i < 6; ++i) {
+ EXPECT_FALSE(b.get(i));
+ }
+ for (size_t i = 7; i < 30; ++i) {
+ EXPECT_FALSE(b.get(i));
+ }
+ for (size_t i = 31; i < bitset::NUM_BITS(); ++i) {
+ EXPECT_FALSE(b.get(i));
+ }
+}
+
+TEST(BitsetTest, givenNonemptyBitset_whenUnsettingBit_thenIsUnset) {
+ bitset b;
+ b.set(6);
+ b.set(30);
+ b.unset(6);
+ EXPECT_FALSE(b.get(6));
+}
+
+TEST(
+ BitsetTest,
+ givenNonemptyBitset_whenUnsettingBit_thenOthersStayAtOldValue) {
+ bitset b;
+ b.set(6);
+ b.set(30);
+ b.unset(6);
+ for (size_t i = 0; i < 30; ++i) {
+ EXPECT_FALSE(b.get(i));
+ }
+ EXPECT_TRUE(b.get(30));
+ for (size_t i = 31; i < bitset::NUM_BITS(); ++i) {
+ EXPECT_FALSE(b.get(i));
+ }
+}
+
+struct IndexCallbackMock final {
+ std::vector<size_t> called_for_indices;
+
+ void operator()(size_t index) {
+ called_for_indices.push_back(index);
+ }
+
+ void expect_was_called_for_indices(std::vector<size_t> expected_indices) {
+ EXPECT_EQ(expected_indices.size(), called_for_indices.size());
+ for (size_t i = 0; i < expected_indices.size(); ++i) {
+ EXPECT_EQ(expected_indices[i], called_for_indices[i]);
+ }
+ }
+};
+
+TEST(BitsetTest, givenEmptyBitset_whenCallingForEachBit_thenDoesntCall) {
+ IndexCallbackMock callback;
+ bitset b;
+ b.for_each_set_bit(callback);
+ callback.expect_was_called_for_indices({});
+}
+
+TEST(
+ BitsetTest,
+ givenBitsetWithOneBitSet_whenCallingForEachBit_thenCallsForEachBit) {
+ IndexCallbackMock callback;
+ bitset b;
+ b.set(5);
+ b.for_each_set_bit(callback);
+ callback.expect_was_called_for_indices({5});
+}
+
+TEST(
+ BitsetTest,
+ givenBitsetWithMultipleBitsSet_whenCallingForEachBit_thenCallsForEachBit) {
+ IndexCallbackMock callback;
+ bitset b;
+ b.set(5);
+ b.set(2);
+ b.set(25);
+ b.set(32);
+ b.set(50);
+ b.set(0);
+ b.unset(25);
+ b.set(10);
+ b.for_each_set_bit(callback);
+ callback.expect_was_called_for_indices({0, 2, 5, 10, 32, 50});
+}
diff --git a/c10/util/Bitset.h b/c10/util/Bitset.h
new file mode 100644
index 0000000..797dfa9
--- /dev/null
+++ b/c10/util/Bitset.h
@@ -0,0 +1,102 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+#include <c10/util/C++17.h>
+#include <c10/util/Optional.h>
+#include <iostream>
+#if defined(_MSC_VER)
+#include <intrin.h>
+#endif
+
+namespace c10 {
+namespace utils {
+
+/**
+ * This is a simple bitset class with sizeof(long long int) bits.
+ * You can set bits, unset bits, query bits by index,
+ * and query for the first set bit.
+ * Before using this class, please also take a look at std::bitset,
+ * which has more functionality and is more generic. It is probably
+ * a better fit for your use case. The sole reason for c10::utils::bitset
+ * to exist is that std::bitset lacks a find_first_set() method.
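+ *
+ * Example (illustrative):
+ *   bitset b;
+ *   b.set(2);
+ *   b.set(5);
+ *   bool two = b.get(2);   // true
+ *   b.unset(2);
+ *   b.for_each_set_bit([](size_t i) { ... });  // called once, with i == 5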
+ */
+struct bitset final {
+private:
+ #if defined(_MSC_VER)
+ // MSVC's _BitScanForward64 expects int64_t
+ using bitset_type = int64_t;
+ #else
+ // POSIX ffsll expects long long int
+ using bitset_type = long long int;
+ #endif
+public:
+ static constexpr size_t NUM_BITS() {
+ return 8 * sizeof(bitset_type);
+ }
+
+ constexpr bitset() noexcept : bitset_(0) {}
+ constexpr bitset(const bitset&) noexcept = default;
+ constexpr bitset(bitset&&) noexcept = default;
+ constexpr bitset& operator=(const bitset&) noexcept = default;
+ constexpr bitset& operator=(bitset&&) noexcept = default;
+
+ constexpr void set(size_t index) noexcept {
+ bitset_ |= (static_cast<long long int>(1) << index);
+ }
+
+ constexpr void unset(size_t index) noexcept {
+ bitset_ &= ~(static_cast<long long int>(1) << index);
+ }
+
+ constexpr bool get(size_t index) const noexcept {
+ return bitset_ & (static_cast<long long int>(1) << index);
+ }
+
+ constexpr bool is_entirely_unset() const noexcept {
+ return 0 == bitset_;
+ }
+
+ // Call the given functor with the index of each bit that is set
+ template <class Func>
+ void for_each_set_bit(Func&& func) const {
+ bitset cur = *this;
+ size_t index = cur.find_first_set();
+ while (0 != index) {
+ // -1 because find_first_set() is one-indexed (a return of 0 means no bit set).
+ index -= 1;
+ func(index);
+ cur.unset(index);
+ index = cur.find_first_set();
+ }
+ }
+
+private:
+ // Return the index of the first set bit. The returned index is one-indexed
+ // (i.e. if the very first bit is set, this function returns '1'), and a
+ // return of '0' means that there was no bit set.
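+ // Example: if only bit 2 is set, this returns 3; for an empty bitset, 0.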
+ size_t find_first_set() const {
+ #if defined(_MSC_VER)
+ unsigned long result;
+ bool has_bits_set = (0 != _BitScanForward64(&result, bitset_));
+ if (!has_bits_set) {
+ return 0;
+ }
+ return result + 1;
+ #else
+ return __builtin_ffsll(bitset_);
+ #endif
+ }
+
+ friend bool operator==(bitset lhs, bitset rhs) noexcept {
+ return lhs.bitset_ == rhs.bitset_;
+ }
+
+ bitset_type bitset_;
+};
+
+inline bool operator!=(bitset lhs, bitset rhs) noexcept {
+ return !(lhs == rhs);
+}
+
+} // namespace utils
+} // namespace c10
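
For reference, the loop shape behind for_each_set_bit, as a minimal standalone
sketch (plain uint64_t instead of the bitset wrapper, and a portable stand-in
for ffsll/_BitScanForward64):

    #include <cstdint>
    #include <cstdio>

    // One-indexed position of the lowest set bit, 0 if no bit is set
    // (the contract of POSIX ffsll and of bitset::find_first_set above).
    int find_first_set(std::uint64_t v) {
      if (v == 0) return 0;
      int pos = 1;
      while ((v & 1) == 0) { v >>= 1; ++pos; }
      return pos;
    }

    int main() {
      std::uint64_t bits = (std::uint64_t(1) << 5) | (std::uint64_t(1) << 0)
                         | (std::uint64_t(1) << 32);
      // Find the lowest set bit, report it zero-indexed, clear it, repeat.
      // Prints 0, 5, 32.
      while (int one_indexed = find_first_set(bits)) {
        int index = one_indexed - 1;
        std::printf("%d\n", index);
        bits &= ~(std::uint64_t(1) << index);
      }
      return 0;
    }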