| // |
| // Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com |
| // |
| // Distributed under the Boost Software License, Version 1.0. (See |
| // accompanying file LICENSE_1_0.txt or copy at |
| // http://www.boost.org/LICENSE_1_0.txt) |
| // |
| // The authors gratefully acknowledge the support of |
| // Fraunhofer IOSB, Ettlingen, Germany |
| // |
| |
| |
| #ifndef BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP |
| #define BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP |
| |
| #include <algorithm> |
| #include <initializer_list> |
| #include <limits> |
| #include <numeric> |
| #include <stdexcept> |
| #include <vector> |
| |
| #include <cassert> |
| |
| namespace boost { |
| namespace numeric { |
| namespace ublas { |
| |
| |
| /** @brief Template class for storing tensor extents with runtime variable size. |
| * |
| * Proxy template class of std::vector<int_type>. |
| * |
| */ |
| template<class int_type> |
| class basic_extents |
| { |
| static_assert( std::numeric_limits<typename std::vector<int_type>::value_type>::is_integer, "Static error in basic_layout: type must be of type integer."); |
| static_assert(!std::numeric_limits<typename std::vector<int_type>::value_type>::is_signed, "Static error in basic_layout: type must be of type unsigned integer."); |
| |
| public: |
| using base_type = std::vector<int_type>; |
| using value_type = typename base_type::value_type; |
| using const_reference = typename base_type::const_reference; |
| using reference = typename base_type::reference; |
| using size_type = typename base_type::size_type; |
| using const_pointer = typename base_type::const_pointer; |
| using const_iterator = typename base_type::const_iterator; |
| |
| |
| /** @brief Default constructs basic_extents |
| * |
| * @code auto ex = basic_extents<unsigned>{}; |
| */ |
| constexpr explicit basic_extents() |
| : _base{} |
| { |
| } |
| |
| /** @brief Copy constructs basic_extents from a one-dimensional container |
| * |
| * @code auto ex = basic_extents<unsigned>( std::vector<unsigned>(3u,3u) ); |
| * |
| * @note checks if size > 1 and all elements > 0 |
| * |
| * @param b one-dimensional std::vector<int_type> container |
| */ |
| explicit basic_extents(base_type const& b) |
| : _base(b) |
| { |
| if (!this->valid()){ |
| throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements."); |
| } |
| } |
| |
| /** @brief Move constructs basic_extents from a one-dimensional container |
| * |
| * @code auto ex = basic_extents<unsigned>( std::vector<unsigned>(3u,3u) ); |
| * |
| * @note checks if size > 1 and all elements > 0 |
| * |
| * @param b one-dimensional container of type std::vector<int_type> |
| */ |
| explicit basic_extents(base_type && b) |
| : _base(std::move(b)) |
| { |
| if (!this->valid()){ |
| throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements."); |
| } |
| } |
| |
| /** @brief Constructs basic_extents from an initializer list |
| * |
| * @code auto ex = basic_extents<unsigned>{3,2,4}; |
| * |
| * @note checks if size > 1 and all elements > 0 |
| * |
| * @param l one-dimensional list of type std::initializer<int_type> |
| */ |
| basic_extents(std::initializer_list<value_type> l) |
| : basic_extents( base_type(std::move(l)) ) |
| { |
| } |
| |
| /** @brief Constructs basic_extents from a range specified by two iterators |
| * |
| * @code auto ex = basic_extents<unsigned>(a.begin(), a.end()); |
| * |
| * @note checks if size > 1 and all elements > 0 |
| * |
| * @param first iterator pointing to the first element |
| * @param last iterator pointing to the next position after the last element |
| */ |
| basic_extents(const_iterator first, const_iterator last) |
| : basic_extents ( base_type( first,last ) ) |
| { |
| } |
| |
| /** @brief Copy constructs basic_extents */ |
| basic_extents(basic_extents const& l ) |
| : _base(l._base) |
| { |
| } |
| |
| /** @brief Move constructs basic_extents */ |
| basic_extents(basic_extents && l ) noexcept |
| : _base(std::move(l._base)) |
| { |
| } |
| |
| ~basic_extents() = default; |
| |
| basic_extents& operator=(basic_extents other) noexcept |
| { |
| swap (*this, other); |
| return *this; |
| } |
| |
| friend void swap(basic_extents& lhs, basic_extents& rhs) { |
| std::swap(lhs._base , rhs._base ); |
| } |
| |
| |
| |
| /** @brief Returns true if this has a scalar shape |
| * |
| * @returns true if (1,1,[1,...,1]) |
| */ |
| bool is_scalar() const |
| { |
| return |
| _base.size() != 0 && |
| std::all_of(_base.begin(), _base.end(), |
| [](const_reference a){ return a == 1;}); |
| } |
| |
| /** @brief Returns true if this has a vector shape |
| * |
| * @returns true if (1,n,[1,...,1]) or (n,1,[1,...,1]) with n > 1 |
| */ |
| bool is_vector() const |
| { |
| if(_base.size() == 0){ |
| return false; |
| } |
| |
| if(_base.size() == 1){ |
| return _base.at(0) > 1; |
| } |
| |
| auto greater_one = [](const_reference a){ return a > 1;}; |
| auto equal_one = [](const_reference a){ return a == 1;}; |
| |
| return |
| std::any_of(_base.begin(), _base.begin()+2, greater_one) && |
| std::any_of(_base.begin(), _base.begin()+2, equal_one ) && |
| std::all_of(_base.begin()+2, _base.end(), equal_one); |
| } |
| |
| /** @brief Returns true if this has a matrix shape |
| * |
| * @returns true if (m,n,[1,...,1]) with m > 1 and n > 1 |
| */ |
| bool is_matrix() const |
| { |
| if(_base.size() < 2){ |
| return false; |
| } |
| |
| auto greater_one = [](const_reference a){ return a > 1;}; |
| auto equal_one = [](const_reference a){ return a == 1;}; |
| |
| return |
| std::all_of(_base.begin(), _base.begin()+2, greater_one) && |
| std::all_of(_base.begin()+2, _base.end(), equal_one ); |
| } |
| |
| /** @brief Returns true if this is has a tensor shape |
| * |
| * @returns true if !empty() && !is_scalar() && !is_vector() && !is_matrix() |
| */ |
| bool is_tensor() const |
| { |
| if(_base.size() < 3){ |
| return false; |
| } |
| |
| auto greater_one = [](const_reference a){ return a > 1;}; |
| |
| return std::any_of(_base.begin()+2, _base.end(), greater_one); |
| } |
| |
| const_pointer data() const |
| { |
| return this->_base.data(); |
| } |
| |
| const_reference operator[] (size_type p) const |
| { |
| return this->_base[p]; |
| } |
| |
| const_reference at (size_type p) const |
| { |
| return this->_base.at(p); |
| } |
| |
| reference operator[] (size_type p) |
| { |
| return this->_base[p]; |
| } |
| |
| reference at (size_type p) |
| { |
| return this->_base.at(p); |
| } |
| |
| |
| bool empty() const |
| { |
| return this->_base.empty(); |
| } |
| |
| size_type size() const |
| { |
| return this->_base.size(); |
| } |
| |
| /** @brief Returns true if size > 1 and all elements > 0 */ |
| bool valid() const |
| { |
| return |
| this->size() > 1 && |
| std::none_of(_base.begin(), _base.end(), |
| [](const_reference a){ return a == value_type(0); }); |
| } |
| |
| /** @brief Returns the number of elements a tensor holds with this */ |
| size_type product() const |
| { |
| if(_base.empty()){ |
| return 0; |
| } |
| |
| return std::accumulate(_base.begin(), _base.end(), 1ul, std::multiplies<>()); |
| |
| } |
| |
| |
| /** @brief Eliminates singleton dimensions when size > 2 |
| * |
| * squeeze { 1,1} -> { 1,1} |
| * squeeze { 2,1} -> { 2,1} |
| * squeeze { 1,2} -> { 1,2} |
| * |
| * squeeze {1,2,3} -> { 2,3} |
| * squeeze {2,1,3} -> { 2,3} |
| * squeeze {1,3,1} -> { 3,1} |
| * |
| */ |
| basic_extents squeeze() const |
| { |
| if(this->size() <= 2){ |
| return *this; |
| } |
| |
| auto new_extent = basic_extents{}; |
| auto insert_iter = std::back_insert_iterator<typename basic_extents::base_type>(new_extent._base); |
| std::remove_copy(this->_base.begin(), this->_base.end(), insert_iter ,value_type{1}); |
| return new_extent; |
| |
| } |
| |
| void clear() |
| { |
| this->_base.clear(); |
| } |
| |
| bool operator == (basic_extents const& b) const |
| { |
| return _base == b._base; |
| } |
| |
| bool operator != (basic_extents const& b) const |
| { |
| return !( _base == b._base ); |
| } |
| |
| const_iterator |
| begin() const |
| { |
| return _base.begin(); |
| } |
| |
| const_iterator |
| end() const |
| { |
| return _base.end(); |
| } |
| |
| base_type const& base() const { return _base; } |
| |
| private: |
| |
| base_type _base; |
| |
| }; |
| |
| using shape = basic_extents<std::size_t>; |
| |
| } // namespace ublas |
| } // namespace numeric |
| } // namespace boost |
| |
| #endif |