| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/framework/tensor_slice.h" |
| #include <vector> |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/strings/numbers.h" |
| #include "tensorflow/core/lib/strings/str_util.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/platform/logging.h" |
| |
| namespace tensorflow { |
| |
| TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); } |
| |
| TensorSlice::TensorSlice(const TensorSliceProto& proto) { |
| starts_.reserve(proto.extent_size()); |
| lengths_.reserve(proto.extent_size()); |
| for (const auto& e : proto.extent()) { |
| starts_.push_back(e.start()); |
| lengths_.push_back(GetExtentLength(e)); |
| } |
| } |
| |
| TensorSlice::TensorSlice( |
| std::initializer_list<std::pair<int64, int64>> extents) { |
| starts_.reserve(extents.size()); |
| lengths_.reserve(extents.size()); |
| for (const auto& e : extents) { |
| starts_.push_back(e.first); |
| lengths_.push_back(e.second); |
| } |
| } |
| |
| Status TensorSlice::Parse(const string& str, TensorSlice* slice) { |
| std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty()); |
| slice->starts_.reserve(items.size()); |
| slice->lengths_.reserve(items.size()); |
| for (const string& x : items) { |
| int64 s, l; |
| if (x == "-") { |
| // "everything" |
| s = 0; |
| l = kFullExtent; |
| } else { |
| std::vector<string> sl = str_util::Split(x, ',', str_util::SkipEmpty()); |
| if (sl.size() != 2 || !strings::safe_strto64(sl[0], &s) || |
| !strings::safe_strto64(sl[1], &l)) { |
| return errors::InvalidArgument( |
| "Expected a pair of numbers or '-' " |
| "but got '", |
| x, "': string = ", str); |
| } |
| if (s < 0 || l <= 0) { |
| return errors::InvalidArgument( |
| "Expected non-negative start and " |
| "positive length but got start = ", |
| s, ", length = ", l, ": string = ", str); |
| } |
| } |
| slice->starts_.push_back(s); |
| slice->lengths_.push_back(l); |
| } |
| |
| return Status::OK(); |
| } |
| |
| void TensorSlice::Clear() { |
| starts_.clear(); |
| lengths_.clear(); |
| } |
| |
| bool TensorSlice::IsFull() const { |
| for (int d = 0; d < dims(); ++d) { |
| if (!IsFullAt(d)) return false; |
| } |
| return true; |
| } |
| |
| void TensorSlice::SetFullSlice(int dim) { |
| Clear(); |
| starts_.reserve(dim); |
| lengths_.reserve(dim); |
| for (int d = 0; d < dim; ++d) { |
| starts_.push_back(0); |
| lengths_.push_back(kFullExtent); |
| } |
| } |
| |
| void TensorSlice::Extend(int dim) { |
| int old_dim = dims(); |
| DCHECK_LE(old_dim, dim); |
| starts_.resize(dim); |
| lengths_.resize(dim); |
| for (int d = old_dim; d < dim; ++d) { |
| starts_[d] = 0; |
| lengths_[d] = kFullExtent; |
| } |
| } |
| |
| void TensorSlice::AsProto(TensorSliceProto* proto) const { |
| for (int d = 0; d < dims(); ++d) { |
| TensorSliceProto::Extent* e = proto->add_extent(); |
| // We only need to record the explicit slice for non-full slices |
| if (!IsFullAt(d)) { |
| e->set_start(starts_[d]); |
| e->set_length(lengths_[d]); |
| } |
| } |
| } |
| |
| string TensorSlice::DebugString() const { |
| string buffer; |
| bool first = true; |
| for (int d = 0; d < dims(); ++d) { |
| if (!first) { |
| buffer.append(":"); |
| } |
| if (IsFullAt(d)) { |
| buffer.append("-"); |
| } else { |
| strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]); |
| } |
| first = false; |
| } |
| return buffer; |
| } |
| |
| bool TensorSlice::Intersect(const TensorSlice& other, |
| TensorSlice* result) const { |
| // First, if two slices have different ranks, they obviously don't overlap |
| // -- in fact they are not compatible. |
| if (dims() != other.dims()) { |
| return false; |
| } |
| |
| // Setting the result to the right dimension |
| if (result) { |
| result->SetFullSlice(dims()); |
| } |
| // The two slices overlap if they overlap in all dimensions. |
| for (int d = 0; d < dims(); ++d) { |
| if (IsFullAt(d)) { |
| if (result) { |
| result->set_start(d, other.start(d)); |
| result->set_length(d, other.length(d)); |
| } |
| } else if (other.IsFullAt(d)) { |
| if (result) { |
| result->set_start(d, start(d)); |
| result->set_length(d, length(d)); |
| } |
| } else { |
| // If we have an intersection here, it should have a start that is the |
| // max of the two starts and an end that is the min of the two ends. |
| int64 s = std::max(start(d), other.start(d)); |
| int64 l = std::min(end(d), other.end(d)) - s; |
| if (l > 0) { |
| // We have a real intersection |
| if (result) { |
| result->set_start(d, s); |
| result->set_length(d, l); |
| } |
| } else { |
| // We don't have an intersection for this dimension -- thus we don't |
| // have any intersection at all. |
| if (result) { |
| result->Clear(); |
| } |
| return false; |
| } |
| } |
| } |
| // If we are here, we know there is overlap in every dimension. |
| return true; |
| } |
| |
| bool TensorSlice::operator==(const TensorSlice& other) const { |
| return dims() == other.dims() && starts_ == other.starts_ && |
| lengths_ == other.lengths_; |
| } |
| |
| void TensorSlice::ComputeRelative(const TensorSlice& sub, |
| TensorSlice* relative) const { |
| DCHECK_EQ(dims(), sub.dims()); |
| relative->SetFullSlice(dims()); |
| for (int d = 0; d < dims(); ++d) { |
| if (IsFullAt(d)) { |
| relative->set_start(d, sub.start(d)); |
| relative->set_length(d, sub.length(d)); |
| } else { |
| // Otherwise the relative start is the difference between the start of |
| // sub and the start of base |
| relative->set_start(d, sub.start(d) - start(d)); |
| relative->set_length(d, sub.length(d)); |
| } |
| } |
| } |
| |
| void TensorSlice::UpdateToCover(const TensorSlice& other) { |
| DCHECK_EQ(dims(), other.dims()); |
| for (int d = 0; d < dims(); ++d) { |
| if (!IsFullAt(d)) { |
| if (other.IsFullAt(d)) { |
| starts_[d] = 0; |
| lengths_[d] = kFullExtent; |
| } else { |
| const auto new_end = std::max(end(d), other.end(d)); |
| set_start(d, std::min(start(d), other.start(d))); |
| set_length(d, new_end - start(d)); |
| } |
| } |
| } |
| } |
| |
| // static |
| bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) { |
| return extent.has_length_case() == TensorSliceProto::Extent::kLength; |
| } |
| |
| // static |
| int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) { |
| if (!HasExtentLength(extent)) return -1; |
| return extent.length(); |
| } |
| |
| Status TensorSlice::SliceTensorShape(const TensorShape& shape, |
| TensorShape* result_shape) const { |
| result_shape->Clear(); |
| // Mismatching ranks: we can't apply the slice at all. |
| if (shape.dims() != dims()) { |
| return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(), |
| ", slice = ", DebugString()); |
| } |
| for (int d = 0; d < dims(); ++d) { |
| if (IsFullAt(d)) { |
| result_shape->AddDim(shape.dim_size(d)); |
| } else { |
| // Check if the extent applies to the dimension |
| if (end(d) <= shape.dim_size(d)) { |
| // Yes: the end is within the range of the dim -- we adjust the result |
| // shape so that its size along this dimension is the length of the |
| // slice. |
| result_shape->AddDim(length(d)); |
| } else { |
| // The extent doesn't apply to the dimension |
| result_shape->Clear(); |
| return errors::Internal("Extent in dimension ", d, |
| " out of bounds: shape = ", shape.DebugString(), |
| ", slice = ", DebugString()); |
| } |
| } |
| } |
| // If we are here, we have successfully applied the shape. |
| return Status::OK(); |
| } |
| |
| const int64 TensorSlice::kFullExtent = -1; |
| |
| } // namespace tensorflow |