blob: d1d5dc1708373e5a5acb7d90423cbb622befc320 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/shape.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
// Builds a Shape from its serialized form. Dimensions, per-dimension dynamic
// flags, tuple element shapes (recursively) and, when present, the layout are
// copied from `shape_proto`.
//
// A malformed proto can carry a different number of is_dynamic_dimension
// entries than dimension entries. A constructor has no clean way to report
// failure, so the mismatch is logged and only the overlapping prefix of flags
// is copied.
// TODO(b/120111794): Make this a hard error once a factory method replaces
// this constructor.
Shape::Shape(const ShapeProto& shape_proto) {
  set_element_type(shape_proto.element_type());

  const int dims_in_proto = shape_proto.dimensions_size();
  const int dynamic_flags_in_proto = shape_proto.is_dynamic_dimension_size();

  dimensions_.reserve(dims_in_proto);
  for (int i = 0; i < dims_in_proto; ++i) {
    add_dimensions(shape_proto.dimensions(i));
  }

  if (dims_in_proto != dynamic_flags_in_proto) {
    if (dynamic_flags_in_proto == 0) {
      LOG(WARNING) << "Malformed shape proto: is_dynamic_dimension is empty";
    } else {
      LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension "
                    "fields does not match number of dimension fields";
    }
  }

  // NOTE(review): the indexed writes below assume add_dimensions() also grew
  // dynamic_dimensions_ by one entry per dimension — confirm in shape.h.
  const int64 dynamic_flags_to_copy =
      std::min(dims_in_proto, dynamic_flags_in_proto);
  for (int64 i = 0; i < dynamic_flags_to_copy; ++i) {
    dynamic_dimensions_[i] = shape_proto.is_dynamic_dimension(i);
  }

  const int num_tuple_elements = shape_proto.tuple_shapes_size();
  tuple_shapes_.reserve(num_tuple_elements);
  for (int i = 0; i < num_tuple_elements; ++i) {
    tuple_shapes_.emplace_back(shape_proto.tuple_shapes(i));
  }

  if (shape_proto.has_layout()) {
    *mutable_layout() = Layout::CreateFromProto(shape_proto.layout());
  }
}
// Serializes this shape into a ShapeProto. The proto receives the element type,
// the dimension sizes, one is_dynamic_dimension flag per entry in
// dynamic_dimensions_, the tuple element shapes (serialized recursively), and
// the layout when one is present.
ShapeProto Shape::ToProto() const {
  ShapeProto proto;
  proto.set_element_type(element_type_);

  proto.mutable_dimensions()->Reserve(dimensions_size());
  for (const int64 dimension : dimensions()) {
    proto.add_dimensions(dimension);
  }

  // Reserve up front like the other repeated fields to avoid regrowth.
  proto.mutable_is_dynamic_dimension()->Reserve(
      static_cast<int>(dynamic_dimensions_.size()));
  for (const bool dynamic : dynamic_dimensions_) {
    proto.add_is_dynamic_dimension(dynamic);
  }

  proto.mutable_tuple_shapes()->Reserve(tuple_shapes_size());
  for (const Shape& shape : tuple_shapes()) {
    *proto.add_tuple_shapes() = shape.ToProto();
  }

  if (has_layout()) {
    *proto.mutable_layout() = layout().ToProto();
  }
  return proto;
}
// Returns a human-readable rendering of this shape. When `print_layout` is
// true the layout is included (ShapeUtil::HumanStringWithLayout); otherwise
// only ShapeUtil::HumanString is used.
string Shape::ToString(bool print_layout) const {
  return print_layout ? ShapeUtil::HumanStringWithLayout(*this)
                      : ShapeUtil::HumanString(*this);
}
// Returns true when no dimension of this shape is marked dynamic. For a tuple,
// every element shape must also be static (checked recursively).
bool Shape::is_static() const {
  if (IsTuple()) {
    for (const Shape& element : tuple_shapes_) {
      if (!element.is_static()) {
        return false;
      }
    }
  }
  for (const bool is_dynamic : dynamic_dimensions_) {
    if (is_dynamic) {
      return false;
    }
  }
  return true;
}
// Removes dimension `dim_to_delete` from this array shape. Its size and its
// dynamic flag are erased. If the shape has a layout, the layout format is set
// to DENSE. The deleted index is dropped from minor_to_major, and every larger
// index there is shifted down by one.
// CHECK-fails unless the shape is an array and 0 <= dim_to_delete < rank.
// NOTE(review): layout tiles, if any, are not adjusted here.
void Shape::DeleteDimension(int64 dim_to_delete) {
  CHECK(IsArray());
  CHECK_GE(dim_to_delete, 0);
  CHECK_LT(dim_to_delete, dimensions_.size());
  dimensions_.erase(dimensions_.begin() + dim_to_delete);
  dynamic_dimensions_.erase(dynamic_dimensions_.begin() + dim_to_delete);

  if (!LayoutUtil::HasLayout(*this)) {
    return;
  }
  layout_.set_format(DENSE);

  auto* m2m = layout_.mutable_minor_to_major();
  int64 pos = 0;
  while (pos < static_cast<int64>(m2m->size())) {
    auto& dim = (*m2m)[pos];
    if (dim == dim_to_delete) {
      // Drop the entry; the next element slides into `pos`, so don't advance.
      m2m->erase(m2m->begin() + pos);
    } else {
      if (dim > dim_to_delete) {
        --dim;
      }
      ++pos;
    }
  }
}
// Returns true if `lhs` and `rhs` are equal under the options configured on
// this Equal object. The element type, fp precision, layout, layout tiles,
// layout element size, layout memory space and dynamic-dimension flags can each
// be ignored via the corresponding ignore_* member. Tuples are compared
// element-wise and recursively, using the same options.
bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
  if (lhs.IsTuple()) {
    // Capture `this` explicitly (implicit capture through [=] is deprecated in
    // C++20) so nested elements are compared with the same options.
    return rhs.IsTuple() &&
           absl::c_equal(
               lhs.tuple_shapes(), rhs.tuple_shapes(),
               [this](const Shape& l, const Shape& r) { return (*this)(l, r); });
  } else if (!lhs.IsArray()) {
    // Non-tuple, non-array types such as opaque and token types have no
    // dimensions or layout, so only the element type is compared.
    return lhs.element_type() == rhs.element_type();
  }

  if (!rhs.IsArray()) {
    return false;
  }

  if (!ignore_element_type_) {
    if ((ignore_fp_precision_ &&
         !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) ||
        (!ignore_fp_precision_ && !ShapeUtil::SameElementType(lhs, rhs))) {
      VLOG(3) << "CompareShapes: lhs element type != rhs element type";
      return false;
    }
  }

  if (!ShapeUtil::SameDimensions(lhs, rhs)) {
    VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
    return false;
  }

  if (!ignore_layout_) {
    if (lhs.layout().format() != rhs.layout().format()) {
      VLOG(3) << "CompareShapes: lhs layout format != rhs layout format";
      return false;
    }
    if (LayoutUtil::IsDenseArray(lhs)) {
      Layout::Equal equal;
      if (ignore_tiles_in_layout_) {
        equal.IgnoreTiles();
      }
      if (ignore_element_size_in_layout_) {
        equal.IgnoreElementSize();
      }
      if (ignore_memory_space_in_layout_) {
        equal.IgnoreMemorySpace();
      }
      if (!equal(lhs.layout(), rhs.layout())) {
        VLOG(3) << "CompareShapes: lhs layout != rhs layout";
        return false;
      }
    }
  }

  if (!ignore_dynamic_dimension_) {
    // SameDimensions() passed above, so both shapes have the same rank.
    for (int i = 0; i < lhs.rank(); ++i) {
      if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) {
        VLOG(3)
            << "CompareShapes: lhs and rhs have different dynamic dimensions.";
        return false;
      }
    }
  }
  return true;
}
// Streams the shape's human-readable form, including its layout.
std::ostream& operator<<(std::ostream& out, const Shape& shape) {
  return out << shape.ToString(/*print_layout=*/true);
}
// Builds a ProgramShape from its serialized form. The parameter shapes, the
// result shape and the parameter names are copied in proto order.
ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
  const int num_params = program_shape_proto.parameters_size();
  for (int idx = 0; idx < num_params; ++idx) {
    *add_parameters() = Shape(program_shape_proto.parameters(idx));
  }

  *mutable_result() = Shape(program_shape_proto.result());

  const int num_names = program_shape_proto.parameter_names_size();
  for (int idx = 0; idx < num_names; ++idx) {
    add_parameter_names(program_shape_proto.parameter_names(idx));
  }
}
// Serializes this program shape: each parameter shape, the result shape and
// the parameter names, in order.
ProgramShapeProto ProgramShape::ToProto() const {
  ProgramShapeProto proto;
  for (int idx = 0; idx < parameters_size(); ++idx) {
    *proto.add_parameters() = parameters(idx).ToProto();
  }

  *proto.mutable_result() = result().ToProto();

  for (int idx = 0; idx < parameter_names_size(); ++idx) {
    proto.add_parameter_names(parameter_names(idx));
  }
  return proto;
}
// Renders the program shape as "(name0: shape0, name1: shape1, ...) -> result".
// A parameter without a corresponding name is labelled "(unknown)".
string ProgramShape::ToString() const {
  std::vector<string> rendered_params;
  rendered_params.reserve(parameters_size());
  for (int idx = 0; idx < parameters_size(); ++idx) {
    // Owned string (not a view): the conditional yields a temporary.
    const string label =
        idx < parameter_names_size() ? parameter_names(idx) : "(unknown)";
    rendered_params.push_back(
        absl::StrCat(label, ": ", ShapeUtil::HumanString(parameters(idx))));
  }
  const string joined = absl::StrJoin(rendered_params, ", ");
  return absl::StrCat("(", joined, ") -> ", ShapeUtil::HumanString(result()));
}
// Streams the program shape's human-readable form followed by a newline.
std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape) {
  return out << program_shape.ToString() << "\n";
}
} // namespace xla