| #ifdef NAMEDTENSOR_ENABLED |
| |
| #include <ATen/NamedTensorUtils.h> |
| #include <sstream> |
| |
| namespace at { |
| |
| void internal_set_names_inplace(Tensor& tensor, optional<DimnameList> names) { |
| impl::internal_set_names_inplace(tensor.unsafeGetTensorImpl(), names); |
| } |
| |
| // Returns "Tensor['N', 'C', 'H', 'W']" for a tensor with names ('N', 'C', 'H', 'W'). |
| static std::string toDimnameRepr(const Tensor& tensor) { |
| std::ostringstream os; |
| os << "Tensor"; |
| if (tensor.names() == nullopt) { |
| os << "["; |
| for (auto i = 0; i < tensor.dim(); i++) { |
| if (i != 0) os << ", "; |
| os << "None"; |
| } |
| os << "]"; |
| } else { |
| os << *tensor.names(); |
| } |
| return os.str(); |
| } |
| |
| int64_t dimname_to_position(const Tensor& tensor, Dimname dim) { |
| TORCH_CHECK(dim.type() != NameType::WILDCARD, |
| "Please look up dimensions by name, got: name = None."); |
| TORCH_CHECK(tensor.names().has_value(), |
| "Name ", dim, " not found in ", toDimnameRepr(tensor), "."); |
| const auto names = *tensor.names(); |
| |
| const auto it = std::find_if( |
| names.begin(), names.end(), |
| [&dim](const Dimname& candidate) { return dim.can_refer_to(candidate); }); |
| TORCH_CHECK(it != names.end(), |
| "Name ", dim, " not found in ", toDimnameRepr(tensor), "."); |
| |
| // Check that it can't refer to another dimension |
| const auto dup = std::find_if( |
| it + 1, names.end(), |
| [&dim](const Dimname& candidate) { return dim.can_refer_to(candidate); }); |
| TORCH_CHECK( |
| dup == names.end(), |
| "Name ", dim, " could refer to multiple dimensions in ", |
| toDimnameRepr(tensor), ". Please disambiguate by using a more ", |
| "specific name like ", *it, " or ", dup, "."); |
| return std::distance(names.begin(), it); |
| } |
| |
| static void report_positional_error( |
| const Dimname& name, |
| const Dimname& other_name, |
| DimnameList names, |
| DimnameList other_names) { |
| // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds |
| TORCH_CHECK(false, |
| "Names ", name, " and ", other_name, " do not match positionally ", |
| "from the right in names ", names, " and ", other_names, "."); |
| } |
| |
| static void check_for_misalignment( |
| const Dimname& name, |
| DimnameList names, |
| DimnameList other_names) { |
| if (name.is_wildcard()) { |
| return; |
| } |
| auto it = std::find_if(other_names.begin(), other_names.end(), |
| [&](const Dimname& candidate) { return name.can_refer_to(candidate); }); |
| // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds |
| TORCH_CHECK(it == other_names.end(), |
| "Names ", names, " and ", other_names, " are misaligned: name ", name, |
| " appears in a different position from the right."); |
| } |
| |
| // Assumption: A DimnameList can have no duplicate full names with |
| // the exception of wildcards |
| static std::vector<Dimname> unify_from_right(DimnameList names, DimnameList other_names) { |
| const auto wildcard = Dimname::wildcard(); |
| const auto size = std::max(names.size(), other_names.size()); |
| auto result = std::vector<Dimname>(size, wildcard); |
| |
| auto names_it = names.rbegin(); |
| auto other_it = other_names.rbegin(); |
| auto result_it = result.rbegin(); |
| while (names_it != names.rend() || other_it != other_names.rend()) { |
| // TODO(zou3519): Don't support tagged names for now. They're a little weird. |
| if (names_it->is_tagged() || other_it->is_tagged()) { |
| TORCH_INTERNAL_ASSERT("unify_from_right: NYI: tagged names."); |
| } |
| |
| const auto& name = names_it == names.rend() ? wildcard : *names_it; |
| const auto& other_name = other_it == other_names.rend() ? wildcard : *other_it; |
| |
| // Step 1: Check that the names match |
| const auto maybeName = unify(name, other_name); |
| if (!maybeName) { |
| report_positional_error(name, other_name, names, other_names); |
| } |
| *result_it = *maybeName; |
| |
| // Step 2: Check that the names are not misaligned |
| if (!names_it->is_normal() || !other_it->is_normal()) { |
| // Let: N = max(len(names), len(other_names)) |
| // K = # of special names among names and other_names. |
| // This search (including the outer loop) is O(N*K) but typically # of dims is small. |
| check_for_misalignment(name, names, other_names); |
| check_for_misalignment(other_name, other_names, names); |
| } |
| |
| if (names_it != names.rend()) { |
| ++names_it; |
| } |
| if (other_it != other_names.rend()) { |
| ++other_it; |
| } |
| ++result_it; |
| } |
| return result; |
| } |
| |
| // Assumption: A DimnameList can have no duplicate full names with |
| // the exception of wildcards |
| CAFFE2_API optional<std::vector<Dimname>> |
| unify_from_right(optional<DimnameList> names, optional<DimnameList> other_names) { |
| if (!names && !other_names) { |
| return nullopt; |
| } |
| if (!names) { |
| return other_names.value().vec(); |
| } |
| if (!other_names) { |
| return names.value().vec(); |
| } |
| return unify_from_right(*names, *other_names); |
| } |
| |
| |
| namespace namedinference { |
| |
| optional<std::vector<Dimname>> erase_name(optional<DimnameList> self_names, int64_t dim) { |
| if (self_names == nullopt) { |
| return nullopt; |
| } |
| auto outnames = self_names->vec(); |
| outnames.erase(outnames.begin() + dim); |
| return outnames; |
| } |
| |
| void propagate_names(Tensor& result, const Tensor& src) { |
| at::internal_set_names_inplace(result, src.names()); |
| } |
| |
| void propagate_names(TensorImpl* result, TensorImpl* src) { |
| const auto names = at::impl::internal_get_names(src); |
| at::impl::internal_set_names_inplace(result, names); |
| } |
| |
| } // namespace namedinference |
| } // namespace at |
| #endif |