|  | #include <c10/core/Contiguity.h> | 
|  | #include <c10/core/MemoryFormat.h> | 
|  | #include <c10/core/SymInt.h> | 
|  | #include <c10/core/SymIntArrayRef.h> | 
|  | #include <c10/core/SymbolicShapeMeta.h> | 
|  |  | 
|  | namespace c10 { | 
|  |  | 
|  | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) | 
|  | SymbolicShapeMeta::SymbolicShapeMeta(const SymbolicShapeMeta& other) | 
|  | // Non-mutables can be accessed outside the mutex | 
|  | : sizes_(other.sizes_), | 
|  | strides_(other.strides_), | 
|  | storage_offset_(other.storage_offset_), | 
|  | strides_valid_(other.strides_valid_) { | 
|  | std::scoped_lock lock(other.mutables_); | 
|  | // These must be copied under lock, so ignore clang-tidy here! | 
|  | // NOLINTBEGIN(cppcoreguidelines-prefer-member-initializer) | 
|  | numel_ = other.numel_; | 
|  | is_contiguous_ = other.is_contiguous_; | 
|  | is_channels_last_contiguous_ = other.is_channels_last_contiguous_; | 
|  | is_channels_last_3d_contiguous_ = other.is_channels_last_3d_contiguous_; | 
|  | is_channels_last_ = other.is_channels_last_; | 
|  | is_channels_last_3d_ = other.is_channels_last_3d_; | 
|  | is_non_overlapping_and_dense_ = other.is_non_overlapping_and_dense_; | 
|  | available_.store(other.available_.load()); | 
|  | // NOLINTEND(cppcoreguidelines-prefer-member-initializer) | 
|  | } | 
|  |  | 
|  | // base, sizes, strides | 
|  | static std::optional< | 
|  | std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>> | 
|  | normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) { | 
|  | // Look for a SymNode to dispatch on | 
|  | SymNode base; | 
|  | bool all_hinted = true; | 
|  | // NB: sizes/strides guaranteed to be positive, so only need | 
|  | // is_heap_allocated | 
|  | for (const auto& s : sizes) { | 
|  | if (all_hinted && !s.has_hint()) { | 
|  | all_hinted = false; | 
|  | } | 
|  | if (!base && s.is_heap_allocated()) { | 
|  | base = s.toSymNode(); | 
|  | } | 
|  | } | 
|  | for (const auto& s : strides) { | 
|  | if (all_hinted && !s.has_hint()) { | 
|  | all_hinted = false; | 
|  | } | 
|  | if (!base && s.is_heap_allocated()) { | 
|  | base = s.toSymNode(); | 
|  | } | 
|  | } | 
|  | if (!base || all_hinted) { | 
|  | // Couldn't find.  Tell the caller to do the normal computation | 
|  | // Alternately, if everything is hinted, we want the normal computation | 
|  | // too | 
|  | return std::nullopt; | 
|  | } | 
|  | // Populate the SymNode array | 
|  | std::vector<SymNode> size_nodes; | 
|  | std::vector<SymNode> stride_nodes; | 
|  | size_nodes.reserve(sizes.size()); | 
|  | stride_nodes.reserve(strides.size()); | 
|  | for (const auto& s : sizes) { | 
|  | size_nodes.emplace_back(s.wrap_node(base)); | 
|  | } | 
|  | for (const auto& s : strides) { | 
|  | stride_nodes.emplace_back(s.wrap_node(base)); | 
|  | } | 
|  | return std::make_optional( | 
|  | std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>( | 
|  | std::move(base), std::move(size_nodes), std::move(stride_nodes))); | 
|  | } | 
|  |  | 
|  | // Special treatment because of numel | 
|  | SymBool SymbolicShapeMeta::compute_contiguous() const { | 
|  | if (!strides_valid_) { | 
|  | return false; | 
|  | } | 
|  | c10::SymIntArrayRef sizes(sizes_); | 
|  | c10::SymIntArrayRef strides(strides_); | 
|  | return _compute_contiguous(sizes, strides, numel()); | 
|  | } | 
|  |  | 
// The remaining compute functions are stamped out by macros.
//
// DEFINE_EAGER_SYMBOOL_COMPUTE(name, nodeimpl, fallback) defines
// SymbolicShapeMeta::name(): it yields false when strides_valid_ is not set
// and otherwise forwards sizes_/strides_ straight to `fallback`. The
// `nodeimpl` argument is not used by this variant.
#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \
  SymBool SymbolicShapeMeta::name() const {                    \
    if (strides_valid_) {                                      \
      c10::SymIntArrayRef size_view(sizes_);                   \
      c10::SymIntArrayRef stride_view(strides_);               \
      return fallback(size_view, stride_view);                 \
    }                                                          \
    return false;                                              \
  }
|  |  | 
// DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback) defines
// SymbolicShapeMeta::name():
//   * yields false when strides_valid_ is not set;
//   * when normalize_sym_sizes_strides() produces a value, dispatches to
//     base->nodeimpl(size_nodes, stride_nodes) and wraps the result in a
//     SymBool;
//   * otherwise forwards sizes_/strides_ to `fallback`.
// The normalized tuple is bound by reference so that its two node vectors are
// not copied just to be passed on.
#define DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback)           \
  SymBool SymbolicShapeMeta::name() const {                        \
    if (!strides_valid_) {                                         \
      return false;                                                \
    }                                                              \
    auto n = normalize_sym_sizes_strides(sizes_, strides_);        \
    if (n.has_value()) {                                           \
      /* bind by reference: avoids copying both node vectors */    \
      auto& [base, size_nodes, stride_nodes] = *n;                 \
      return SymBool(base->nodeimpl(size_nodes, stride_nodes));    \
    } else {                                                       \
      c10::SymIntArrayRef sizes(sizes_);                           \
      c10::SymIntArrayRef strides(strides_);                       \
      return fallback(sizes, strides);                             \
    }                                                              \
  }
|  |  | 
|  | // clang-format off | 
|  | DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, is_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d) | 
|  | DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, is_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d) | 
|  | DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d, is_channels_last_strides_2d) | 
|  | DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d, is_channels_last_strides_3d) | 
|  | DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense) | 
|  | // clang-format on | 
|  |  | 
|  | #undef DEFINE_SYMBOOL_COMPUTE | 
|  |  | 
|  | // Glue compute | 
|  | // NB: this logic very intentionally short circuits if possible.  Without | 
|  | // short circuiting, it causes | 
|  | // python test/functorch/test_aotdispatch.py -k | 
|  | // test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32 to run | 
|  | // very slowly. | 
|  |  | 
|  | SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const { | 
|  | init_is_contiguous(); | 
|  | if (definitely_true(is_contiguous(), __FILE__, __LINE__)) { | 
|  | return true; | 
|  | } | 
|  | init_is_channels_last_contiguous(); | 
|  | if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) { | 
|  | return true; | 
|  | } | 
|  | return is_contiguous() | is_channels_last_contiguous() | | 
|  | compute_non_overlapping_and_dense(); | 
|  | } | 
|  |  | 
|  | SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const { | 
|  | init_is_channels_last_contiguous(); | 
|  | if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) { | 
|  | return false; | 
|  | } | 
|  | return ~is_channels_last_contiguous() & compute_channels_last_contiguous_3d(); | 
|  | } | 
|  |  | 
|  | SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const { | 
|  | init_is_channels_last_3d_contiguous(); | 
|  | if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) { | 
|  | return false; | 
|  | } | 
|  | return ~is_channels_last_3d_contiguous() & | 
|  | compute_strides_like_channels_last_2d(); | 
|  | } | 
|  |  | 
|  | SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const { | 
|  | if (definitely_true(is_channels_last(), __FILE__, __LINE__)) { | 
|  | return false; | 
|  | } | 
|  | return ~is_channels_last() & compute_strides_like_channels_last_3d(); | 
|  | } | 
|  |  | 
|  | SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const { | 
|  | if (definitely_true(is_contiguous(), __FILE__, __LINE__)) { | 
|  | return true; | 
|  | } | 
|  | if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) { | 
|  | return true; | 
|  | } | 
|  | if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) { | 
|  | return true; | 
|  | } | 
|  | return is_contiguous() | is_channels_last_contiguous() | | 
|  | is_channels_last_3d_contiguous() | compute_non_overlapping_and_dense(); | 
|  | } | 
|  |  | 
|  | SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const { | 
|  | if (definitely_true(is_contiguous(), __FILE__, __LINE__)) { | 
|  | return true; | 
|  | } | 
|  | return is_contiguous() | compute_non_overlapping_and_dense(); | 
|  | } | 
|  |  | 
|  | // NOLINTNEXTLINE(performance-unnecessary-value-param) | 
|  | void SymbolicShapeMeta::set_numel(SymInt val) const { | 
|  | std::scoped_lock lock(mutables_); | 
|  | if (has_numel()) { | 
|  | return; | 
|  | } | 
|  | numel_ = std::move(val); | 
|  | available_.fetch_or(numel_avail); | 
|  | } | 
|  | void SymbolicShapeMeta::set_is_contiguous(SymBool val) const { | 
|  | std::scoped_lock lock(mutables_); | 
|  | if (has_is_contiguous()) { | 
|  | return; | 
|  | } | 
|  | is_contiguous_ = std::move(val); | 
|  | available_.fetch_or(is_contiguous_avail); | 
|  | } | 
|  | void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const { | 
|  | std::scoped_lock lock(mutables_); | 
|  | if (has_is_channels_last_contiguous()) { | 
|  | return; | 
|  | } | 
|  | is_channels_last_contiguous_ = std::move(val); | 
|  | available_.fetch_or(is_channels_last_contiguous_avail); | 
|  | } | 
|  | void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const { | 
|  | std::scoped_lock lock(mutables_); | 
|  | if (has_is_channels_last_3d_contiguous()) { | 
|  | return; | 
|  | } | 
|  | is_channels_last_3d_contiguous_ = std::move(val); | 
|  | available_.fetch_or(is_channels_last_3d_contiguous_avail); | 
|  | } | 
|  | void SymbolicShapeMeta::set_is_channels_last(SymBool val) const { | 
|  | std::scoped_lock lock(mutables_); | 
|  | if (has_is_channels_last()) { | 
|  | return; | 
|  | } | 
|  | is_channels_last_ = std::move(val); | 
|  | available_.fetch_or(is_channels_last_avail); | 
|  | } | 
|  | void SymbolicShapeMeta::set_is_channels_last_3d(SymBool val) const { | 
|  | std::scoped_lock lock(mutables_); | 
|  | if (has_is_channels_last_3d()) { | 
|  | return; | 
|  | } | 
|  | is_channels_last_3d_ = std::move(val); | 
|  | available_.fetch_or(is_channels_last_3d_avail); | 
|  | } | 
|  |  | 
|  | void SymbolicShapeMeta::set_is_non_overlapping_and_dense(SymBool val) const { | 
|  | std::scoped_lock lock(mutables_); | 
|  | if (has_is_non_overlapping_and_dense()) { | 
|  | return; | 
|  | } | 
|  | is_non_overlapping_and_dense_ = std::move(val); | 
|  | available_.fetch_or(is_non_overlapping_and_dense_avail); | 
|  | } | 
|  |  | 
|  | void SymbolicShapeMeta::init_numel() const { | 
|  | set_numel(multiply_integers(sizes_)); | 
|  | } | 
|  |  | 
|  | void SymbolicShapeMeta::init_is_contiguous() const { | 
|  | set_is_contiguous(compute_contiguous()); | 
|  | } | 
|  |  | 
|  | void SymbolicShapeMeta::init_is_channels_last_contiguous() const { | 
|  | set_is_channels_last_contiguous([&] { | 
|  | switch (dim()) { | 
|  | case 5: | 
|  | case 4: { | 
|  | return compute_channels_last_contiguous_2d(); | 
|  | } | 
|  | default: | 
|  | return SymBool{false}; | 
|  | } | 
|  | }()); | 
|  | } | 
|  |  | 
|  | void SymbolicShapeMeta::init_is_channels_last_3d_contiguous() const { | 
|  | set_is_channels_last_3d_contiguous([&] { | 
|  | switch (dim()) { | 
|  | case 5: | 
|  | return compute_channels_last_contiguous_3d_dim5(); | 
|  | default: | 
|  | return SymBool{false}; | 
|  | } | 
|  | }()); | 
|  | } | 
|  |  | 
|  | void SymbolicShapeMeta::init_is_channels_last() const { | 
|  | set_is_channels_last([&] { | 
|  | switch (dim()) { | 
|  | case 5: | 
|  | return compute_channels_last_2d_dim5(); | 
|  | case 4: | 
|  | return compute_strides_like_channels_last_2d(); | 
|  | default: | 
|  | return SymBool{false}; | 
|  | } | 
|  | }()); | 
|  | } | 
|  |  | 
|  | void SymbolicShapeMeta::init_is_channels_last_3d() const { | 
|  | set_is_channels_last_3d([&] { | 
|  | switch (dim()) { | 
|  | case 5: | 
|  | return compute_channels_last_3d_dim5(); | 
|  | default: | 
|  | return SymBool{false}; | 
|  | } | 
|  | }()); | 
|  | } | 
|  |  | 
|  | void SymbolicShapeMeta::init_is_non_overlapping_and_dense() const { | 
|  | set_is_non_overlapping_and_dense([&] { | 
|  | switch (dim()) { | 
|  | case 5: | 
|  | return compute_is_non_overlapping_and_dense_dim5(); | 
|  | case 4: | 
|  | return compute_is_non_overlapping_and_dense_dim4(); | 
|  | default: | 
|  | return compute_is_non_overlapping_and_dense_anydim(); | 
|  | } | 
|  | }()); | 
|  | } | 
|  |  | 
|  | } // namespace c10 |