| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/framework/common_shape_fns.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/shape_inference.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/util/mirror_pad_mode.h" |
| #include "tensorflow/core/util/padding.h" |
| #include "tensorflow/core/util/strided_slice_op.h" |
| #include "tensorflow/core/util/tensor_format.h" |
| |
| namespace tensorflow { |
| |
| using shape_inference::DimensionHandle; |
| using shape_inference::InferenceContext; |
| using shape_inference::ShapeHandle; |
| using shape_inference::UnchangedShape; |
| |
| namespace { |
| |
| Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack, |
| int32* axis) { |
| TF_RETURN_IF_ERROR(c->GetAttr("axis", axis)); |
| if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) { |
| return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [", |
| -1 * rank_after_pack, ",", rank_after_pack, |
| ")"); |
| } |
| if (*axis < 0) *axis = (rank_after_pack + *axis); |
| return Status::OK(); |
| } |
| |
| template <typename T> |
| std::vector<int64> AsInt64(const Tensor* tensor, int64 num_elements) { |
| std::vector<int64> ret(num_elements); |
| auto data = tensor->vec<T>(); |
| for (int64 i = 0; i < num_elements; ++i) { |
| ret[i] = data(i); |
| } |
| return ret; |
| } |
| |
| template <typename T> |
| Status PadKnown(InferenceContext* c, ShapeHandle input, |
| const Tensor* paddings_t, int64 num_dims) { |
| // paddings_t is known. |
| std::vector<DimensionHandle> dims(num_dims); |
| auto paddings_data = paddings_t->matrix<T>(); |
| for (int64 i = 0; i < num_dims; ++i) { |
| const T pad0 = paddings_data(i, 0); |
| const T pad1 = paddings_data(i, 1); |
| if (pad0 < 0 || pad1 < 0) { |
| return errors::InvalidArgument("Paddings must be non-negative"); |
| } |
| TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i])); |
| } |
| c->set_output(0, c->MakeShape(dims)); |
| return Status::OK(); |
| } |
| |
| Status PadShapeFn(InferenceContext* c) { |
| // Paddings is a matrix of [input_rank, 2]. |
| ShapeHandle paddings; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings)); |
| DimensionHandle unused; |
| TF_RETURN_IF_ERROR(c->WithValue(c->Dim(paddings, 1), 2, &unused)); |
| |
| // n_dim and input.rank are equivalent. |
| ShapeHandle input = c->input(0); |
| DimensionHandle n_dim = c->Dim(paddings, 0); |
| if (c->ValueKnown(n_dim)) { |
| TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(n_dim), &input)); |
| } else if (c->RankKnown(input)) { |
| TF_RETURN_IF_ERROR(c->WithValue(n_dim, c->Rank(input), &n_dim)); |
| } |
| |
| const Tensor* paddings_t = c->input_tensor(1); |
| |
| // paddings_t is unknown |
| if (paddings_t == nullptr) { |
| if (c->ValueKnown(n_dim)) { |
| // Make output with n_dim unknown dims. |
| c->set_output(0, c->UnknownShapeOfRank(c->Value(n_dim))); |
| } else { |
| c->set_output(0, c->UnknownShape()); |
| } |
| return Status::OK(); |
| } |
| |
| const int64 num_dims = paddings_t->shape().dim_size(0); |
| TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input)); |
| TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim)); |
| |
| if (paddings_t->dtype() == DT_INT32) { |
| return PadKnown<int32>(c, input, paddings_t, num_dims); |
| } else { |
| return PadKnown<int64>(c, input, paddings_t, num_dims); |
| } |
| } |
| |
| Status TransposeShapeFn(InferenceContext* c) { |
| ShapeHandle input = c->input(0); |
| ShapeHandle perm_shape = c->input(1); |
| const Tensor* perm = c->input_tensor(1); |
| DimensionHandle perm_elems = c->NumElements(perm_shape); |
| // If we don't have rank information on the input or value information on |
| // perm we can't return any shape information, otherwise we have enough |
| // information to at least find the rank of the output. |
| if (!c->RankKnown(input) && !c->ValueKnown(perm_elems) && perm == nullptr) { |
| c->set_output(0, c->UnknownShape()); |
| return Status::OK(); |
| } |
| |
| // Find our value of the rank. |
| int64 rank; |
| if (c->RankKnown(input)) { |
| rank = c->Rank(input); |
| } else if (c->ValueKnown(perm_elems)) { |
| rank = c->Value(perm_elems); |
| } else { |
| rank = perm->NumElements(); |
| } |
| std::vector<DimensionHandle> dims; |
| dims.resize(rank); |
| TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input)); |
| // Ensure that perm is a vector and has rank elements. |
| TF_RETURN_IF_ERROR(c->WithRank(perm_shape, 1, &perm_shape)); |
| TF_RETURN_IF_ERROR(c->WithValue(perm_elems, rank, &perm_elems)); |
| |
| // If we know the rank of the input and the value of perm, we can return |
| // all shape informantion, otherwise we can only return rank information, |
| // but no information for the dimensions. |
| if (perm != nullptr) { |
| std::vector<int64> data; |
| if (perm->dtype() == DT_INT32) { |
| data = AsInt64<int32>(perm, rank); |
| } else { |
| data = AsInt64<int64>(perm, rank); |
| } |
| |
| for (int32 i = 0; i < rank; ++i) { |
| int64 in_idx = data[i]; |
| if (in_idx >= rank) { |
| return errors::InvalidArgument("perm dim ", in_idx, |
| " is out of range of input rank ", rank); |
| } |
| dims[i] = c->Dim(input, in_idx); |
| } |
| } else { |
| for (int i = 0; i < rank; ++i) { |
| dims[i] = c->UnknownDim(); |
| } |
| } |
| |
| c->set_output(0, c->MakeShape(dims)); |
| return Status::OK(); |
| } |
| |
| Status SetOutputShapeForReshape(InferenceContext* c) { |
| ShapeHandle in = c->input(0); |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); |
| |
| if (!c->RankKnown(out)) { |
| // We have no information about the shape of the output. |
| c->set_output(0, out); |
| return Status::OK(); |
| } |
| |
| if (c->RankKnown(out) && c->RankKnown(in)) { |
| // We don't know the number of output elements, but we can try to infer |
| // the missing dimension. |
| bool too_many_unknown = false; |
| int32 out_unknown_idx = -1; |
| |
| DimensionHandle known_out_elems = c->NumElements(out); |
| if (!c->ValueKnown(known_out_elems)) { |
| known_out_elems = c->MakeDim(1); |
| for (int32 i = 0; i < c->Rank(out); ++i) { |
| DimensionHandle dim = c->Dim(out, i); |
| if (!c->ValueKnown(dim)) { |
| if (out_unknown_idx >= 0) { |
| too_many_unknown = true; |
| break; |
| } |
| out_unknown_idx = i; |
| } else { |
| TF_RETURN_IF_ERROR( |
| c->Multiply(known_out_elems, dim, &known_out_elems)); |
| } |
| } |
| } |
| int32 in_unknown_idx = -1; |
| DimensionHandle known_in_elems = c->NumElements(in); |
| if (!c->ValueKnown(known_in_elems)) { |
| known_in_elems = c->MakeDim(1); |
| for (int32 i = 0; i < c->Rank(in); ++i) { |
| DimensionHandle dim = c->Dim(in, i); |
| if (!c->ValueKnown(dim)) { |
| if (in_unknown_idx >= 0) { |
| too_many_unknown = true; |
| break; |
| } |
| in_unknown_idx = i; |
| } else { |
| TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems)); |
| } |
| } |
| } |
| |
| if (!too_many_unknown) { |
| if (in_unknown_idx < 0 && out_unknown_idx < 0) { |
| // Just check that the dimensions match. |
| if (c->Value(known_in_elems) != c->Value(known_out_elems)) { |
| return errors::InvalidArgument( |
| "Cannot reshape a tensor with ", c->DebugString(known_in_elems), |
| " elements to shape ", c->DebugString(out), " (", |
| c->DebugString(known_out_elems), " elements)"); |
| } |
| } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 && |
| c->Value(known_out_elems) > 0) { |
| // Input fully known, infer the one missing output dim |
| DimensionHandle inferred_dim; |
| TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems), |
| true /* evenly_divisible */, |
| &inferred_dim)); |
| TF_RETURN_IF_ERROR( |
| c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out)); |
| |
| } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 && |
| c->Value(known_in_elems) != 0) { |
| // Output fully known, infer the one missing input dim |
| DimensionHandle inferred_dim; |
| TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems), |
| true /* evenly_divisible */, |
| &inferred_dim)); |
| DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx); |
| TF_RETURN_IF_ERROR( |
| c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim)); |
| } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) { |
| // Exactly one unknown dimension in both input and output. These 2 are |
| // equal iff the known elements are equal. |
| if (c->Value(known_in_elems) == c->Value(known_out_elems)) { |
| DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx); |
| TF_RETURN_IF_ERROR( |
| c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out)); |
| } |
| } |
| } |
| } |
| c->set_output(0, out); |
| return Status::OK(); |
| } |
| |
| } // namespace |
| |
| REGISTER_OP("ParallelConcat") |
| .Input("values: N * T") |
| .Output("output: T") |
| .Attr("N: int >= 1") |
| .Attr("T: type") |
| .Attr("shape: shape") |
| .SetShapeFn([](InferenceContext* c) { |
| // Validate that the shape attr is correct. |
| PartialTensorShape shape; |
| TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); |
| ShapeHandle passed_shape; |
| TF_RETURN_IF_ERROR( |
| c->MakeShapeFromPartialTensorShape(shape, &passed_shape)); |
| if (!c->FullyDefined(passed_shape)) { |
| return errors::InvalidArgument("shape attr must be fully defined."); |
| } |
| ShapeHandle cur; |
| TF_RETURN_IF_ERROR(c->ReplaceDim( |
| passed_shape, 0, c->MakeDim(shape_inference::DimensionOrConstant(1)), |
| &cur)); |
| for (int i = 0; i < c->num_inputs(); ++i) { |
| if (!c->FullyDefined(c->input(i))) { |
| return errors::InvalidArgument( |
| "All input shapes must be fully defined."); |
| } |
| DimensionHandle unused; |
| if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) { |
| return errors::InvalidArgument("Size of first dimension must be 1."); |
| } |
| TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), |
| "From merging shape ", i, |
| " with other shapes."); |
| } |
| |
| c->set_output(0, passed_shape); |
| |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("Pack") |
| .Input("values: N * T") |
| .Output("output: T") |
| .Attr("N: int >= 1") |
| .Attr("T: type") |
| .Attr("axis: int = 0") |
| .SetShapeFn([](InferenceContext* c) { |
| // Validate shapes of all inputs are compatible |
| ShapeHandle cur = c->input(c->num_inputs() - 1); |
| for (int i = c->num_inputs() - 2; i >= 0; --i) { |
| TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), |
| "From merging shape ", i, |
| " with other shapes."); |
| } |
| if (!c->RankKnown(cur)) { |
| c->set_output(0, c->UnknownShape()); |
| return Status::OK(); |
| } |
| // Determine the axis that will be added, converting from negative |
| // axes to a positive point per negative indexing rules. |
| int32 rank = c->Rank(cur); |
| int32 axis; |
| TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis)); |
| |
| // Copy all dimensions over, inserting a dimension of value #inputs |
| // at <axis>. |
| std::vector<DimensionHandle> dims; |
| int index = 0; |
| while (index < axis) dims.push_back(c->Dim(cur, index++)); |
| dims.push_back(c->MakeDim(c->num_inputs())); |
| while (index < rank) dims.push_back(c->Dim(cur, index++)); |
| |
| c->set_output(0, c->MakeShape(dims)); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("DeepCopy") |
| .Input("x: T") |
| .Output("y: T") |
| .Attr("T: type") |
| .SetIsStateful() |
| .SetShapeFn(UnchangedShape); |
| |
| REGISTER_OP("InplaceUpdate") |
| .Input("x: T") |
| .Input("i: int32") |
| .Input("v: T") |
| .Output("y: T") |
| .Attr("T: type") |
| .SetShapeFn(UnchangedShape); |
| |
| REGISTER_OP("InplaceAdd") |
| .Input("x: T") |
| .Input("i: int32") |
| .Input("v: T") |
| .Output("y: T") |
| .Attr("T: type") |
| .SetShapeFn(UnchangedShape); |
| |
| REGISTER_OP("InplaceSub") |
| .Input("x: T") |
| .Input("i: int32") |
| .Input("v: T") |
| .Output("y: T") |
| .Attr("T: type") |
| .SetShapeFn(UnchangedShape); |
| |
| REGISTER_OP("Empty") |
| .Input("shape: int32") |
| .Output("output: dtype") |
| .Attr("dtype: type") |
| .Attr("init: bool = false") |
| .SetIsStateful() |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
| c->set_output(0, out); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Unpack") |
| .Input("value: T") |
| .Output("output: num * T") |
| .Attr("num: int >= 0") |
| .Attr("T: type") |
| .Attr("axis: int = 0") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle s = c->input(0); |
| ShapeHandle out; |
| if (c->RankKnown(s)) { |
| // Determine the axis that will be removed, converting from negative |
| // axes to a positive point per negative indexing rules. |
| int32 rank = c->Rank(s); |
| int32 axis; |
| TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis)); |
| |
| // The axis dim matches the number of outputs. |
| DimensionHandle unused; |
| TF_RETURN_IF_ERROR( |
| c->WithValue(c->Dim(s, axis), c->num_outputs(), &unused)); |
| |
| // Copy all dimensions, removing the <axis> dimension. |
| std::vector<DimensionHandle> dims; |
| for (int i = 0; i < rank; ++i) { |
| if (i != axis) dims.push_back(c->Dim(s, i)); |
| } |
| out = c->MakeShape(dims); |
| } else { |
| // All outputs are the same shape, but it's not known. |
| out = c->UnknownShape(); |
| } |
| for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("UnravelIndex") |
| .Input("indices: Tidx") |
| .Input("dims: Tidx") |
| .Output("output: Tidx") |
| .Attr("Tidx: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle indices = c->input(0); |
| ShapeHandle dims; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims)); |
| if (c->RankKnown(indices) && c->Rank(indices) == 0) { |
| c->set_output(0, c->Vector(c->Dim(dims, 0))); |
| } else if (c->RankKnown(indices)) { |
| c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices))); |
| } else { |
| c->set_output(0, c->UnknownShape()); |
| } |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("BroadcastTo") |
| .Input("input: T") |
| .Input("shape: Tidx") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Tidx: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle in = c->input(0); |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); |
| |
| if (!c->RankKnown(out)) { |
| // We have no information about the shape of the output. |
| c->set_output(0, out); |
| return Status::OK(); |
| } |
| |
| if (!c->RankKnown(in)) { |
| // We have no information about the shape of the input, |
| // nothing to do here. |
| c->set_output(0, out); |
| return Status::OK(); |
| } |
| if (c->Rank(out) < c->Rank(in)) { |
| return errors::InvalidArgument("Cannot broadcast a tensor with shape ", |
| c->DebugString(in), " shape ", |
| c->DebugString(out)); |
| } |
| |
| int32 in_offset = c->Rank(out) - c->Rank(in); |
| for (int32 i = 0; i < c->Rank(out); ++i) { |
| DimensionHandle dim = c->Dim(out, i); |
| if (c->ValueKnown(dim)) { |
| // The first in_offset dimensions for input will be expanded with 1, |
| // so no check needed. |
| if (i >= in_offset) { |
| DimensionHandle in_dim = c->Dim(in, i - in_offset); |
| if (c->ValueKnown(in_dim) && c->Value(in_dim) != 0) { |
| if (c->Value(dim) % c->Value(in_dim) != 0) { |
| return errors::InvalidArgument( |
| "Cannot broadcast a tensor with shape ", c->DebugString(in), |
| " shape ", c->DebugString(out)); |
| } |
| } |
| } |
| } |
| } |
| |
| c->set_output(0, out); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| // TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph |
| // in the N == 1 case to remove the node. |
| REGISTER_OP("Concat") |
| .Input("concat_dim: int32") |
| .Input("values: N * T") |
| .Output("output: T") |
| .Attr("N: int >= 2") |
| .Attr("T: type") |
| .SetShapeFn([](InferenceContext* c) { |
| return shape_inference::ConcatShape(c, c->num_inputs() - 1); |
| }); |
| |
| REGISTER_OP("ConcatV2") |
| .Input("values: N * T") |
| .Input("axis: Tidx") |
| .Output("output: T") |
| .Attr("N: int >= 2") |
| .Attr("T: type") |
| .Attr("Tidx: {int32, int64} = DT_INT32") |
| .SetShapeFn(shape_inference::ConcatV2Shape); |
| |
| // TODO(vivek.v.rane@intel.com): Prefix the op names with underscore if the ops |
| // are not to be made user-accessible. |
| #ifdef INTEL_MKL |
| REGISTER_OP("_MklConcatV2") |
| .Input("values: N * T") |
| .Input("axis: Tidx") |
| .Input("mkl_values: N * uint8") |
| .Input("mkl_axis: uint8") |
| .Output("output: T") |
| .Output("mkl_output: uint8") |
| .Attr("N: int >= 2") |
| .Attr("T: type") |
| .Attr("Tidx: {int32, int64} = DT_INT32") |
| .SetShapeFn(shape_inference::ConcatV2Shape) |
| .Doc(R"doc( |
| MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation. |
| |
| NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
| expected to invoke these operators. |
| )doc"); |
| #endif |
| |
| REGISTER_OP("ConcatOffset") |
| .Input("concat_dim: int32") |
| .Input("shape: N * int32") |
| .Output("offset: N * int32") |
| .Attr("N: int >= 2") |
| .SetShapeFn([](InferenceContext* c) { |
| for (int i = 1; i < c->num_inputs(); ++i) { |
| c->set_output(i - 1, c->input(i)); |
| } |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Split") |
| .Input("split_dim: int32") |
| .Input("value: T") |
| .Output("output: num_split * T") |
| .Attr("num_split: int >= 1") |
| .Attr("T: type") |
| .SetShapeFn([](InferenceContext* c) { |
| DimensionHandle split_dimension; |
| ShapeHandle input = c->input(1); |
| TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing( |
| 0, c->Rank(input), &split_dimension)); |
| int num_split = c->num_outputs(); |
| ShapeHandle out; |
| if (!c->ValueKnown(split_dimension)) { |
| if (c->RankKnown(input)) { |
| out = c->UnknownShapeOfRank(c->Rank(input)); |
| } else { |
| out = c->UnknownShape(); |
| } |
| } else { |
| int64 split_dim = c->Value(split_dimension); |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input)); |
| DimensionHandle split_dim_size; |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| c->Divide(c->Dim(input, split_dim), num_split, |
| true /* evenly_divisible */, &split_dim_size), |
| "Number of ways to split should evenly divide the split dimension"); |
| TF_RETURN_IF_ERROR( |
| c->ReplaceDim(input, split_dim, split_dim_size, &out)); |
| } |
| for (int i = 0; i < num_split; ++i) c->set_output(i, out); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("SplitV") |
| .Input("value: T") |
| .Input("size_splits: Tlen") |
| .Input("split_dim: int32") |
| .Output("output: num_split * T") |
| .Attr("num_split: int >= 1") |
| .Attr("T: type") |
| .Attr("Tlen: {int32, int64} = DT_INT64") |
| .SetShapeFn([](InferenceContext* c) { |
| DimensionHandle split_dimension; |
| ShapeHandle input = c->input(0); |
| TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing( |
| 2, c->Rank(input), &split_dimension)); |
| int32 num_outputs = c->num_outputs(); |
| int32 rank = c->Rank(input); |
| ShapeHandle output_shape; |
| const Tensor* size_splits = c->input_tensor(1); |
| if (rank == InferenceContext::kUnknownRank) { |
| // If the rank of input tensor is unknown, then return unknown shapes. |
| // Note that the shape of each output can be different. |
| for (int i = 0; i < num_outputs; ++i) { |
| c->set_output(i, c->UnknownShape()); |
| } |
| } else if (rank == 0) { |
| // Throw error if input is a scalar. |
| return errors::InvalidArgument("Can't split scalars"); |
| } else if (size_splits == nullptr && c->ValueKnown(split_dimension)) { |
| // If split dimension is known, but the sizes are unknown, then |
| // only the split dimension is unknown |
| output_shape = input; |
| for (int i = 0; i < num_outputs; ++i) { |
| TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, |
| c->Value(split_dimension), |
| c->UnknownDim(), &output_shape)); |
| c->set_output(i, output_shape); |
| } |
| } else if (size_splits == nullptr && !c->ValueKnown(split_dimension)) { |
| // If split dimension or tensor containing the split sizes is unknown, |
| // then return unknown shapes of same rank as input. Note that each |
| // output shape can be different since splitv doesn't always split |
| // tensors evenly. |
| for (int i = 0; i < num_outputs; ++i) { |
| c->set_output(i, c->UnknownShapeOfRank(rank)); |
| } |
| } else { |
| // Determine the output shape if split dimension and split sizes are |
| // known. |
| int64 split_dim = c->Value(split_dimension); |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input)); |
| std::vector<int64> data; |
| if (size_splits->dtype() == DT_INT32) { |
| data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0)); |
| } else { |
| data = AsInt64<int64>(size_splits, size_splits->shape().dim_size(0)); |
| } |
| if (num_outputs != data.size()) { |
| return errors::InvalidArgument( |
| "Length of size_splits should be equal to num_outputs"); |
| } |
| int64_t total_size = 0; |
| bool has_neg_one = false; |
| for (const auto size : data) { |
| if (size == -1) { |
| if (has_neg_one) { |
| return errors::InvalidArgument( |
| "size_splits can only have one -1"); |
| } |
| has_neg_one = true; |
| } else { |
| total_size += size; |
| } |
| } |
| auto split_dim_size = c->Value(c->Dim(input, split_dim)); |
| // If the sizes of the splits are known, then |
| // make sure that the sizes add up to the expected |
| // dimension size, with the possibility of a -1. |
| // Specify the full output shapes. |
| for (int i = 0; i < num_outputs; ++i) { |
| auto size = data[i]; |
| if (data[i] == -1 && c->ValueKnown(split_dim_size)) { |
| size = split_dim_size - total_size; |
| } |
| TF_RETURN_IF_ERROR( |
| c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape)); |
| c->set_output(i, output_shape); |
| } |
| if (c->ValueKnown(split_dim_size)) { |
| if (has_neg_one ? total_size > split_dim_size |
| : total_size != split_dim_size) { |
| return errors::InvalidArgument( |
| "can't split axis of size ", split_dim_size, |
| " into pieces of size [", str_util::Join(data, ","), "]"); |
| } |
| } |
| } |
| |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Const") |
| .Output("output: dtype") |
| .Attr("value: tensor") |
| .Attr("dtype: type") |
| .SetShapeFn([](InferenceContext* c) { |
| const TensorProto* proto = nullptr; |
| TF_RETURN_IF_ERROR(c->GetAttr("value", &proto)); |
| TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape())); |
| TensorShape shape(proto->tensor_shape()); |
| std::vector<DimensionHandle> dims; |
| dims.reserve(shape.dims()); |
| for (int i = 0; i < shape.dims(); ++i) { |
| dims.push_back(c->MakeDim(shape.dim_size(i))); |
| } |
| c->set_output(0, c->MakeShape(dims)); |
| return Status::OK(); |
| }); |
| |
| // Returns a constant tensor on the host. Useful for writing C++ tests |
| // and benchmarks which run on GPU but require arguments pinned to the host. |
| // Used by test::graph::HostConstant. |
| // value: Attr `value` is the tensor to return. |
| REGISTER_OP("HostConst") |
| .Output("output: dtype") |
| .Attr("value: tensor") |
| .Attr("dtype: type") |
| .SetShapeFn(shape_inference::UnknownShape); |
| |
| // -------------------------------------------------------------------------- |
| // TODO(mgubin): Update the doc when the freeze_graph script supports converting |
| // into memmapped format. |
| REGISTER_OP("ImmutableConst") |
| .Attr("dtype: type") |
| .Attr("shape: shape") |
| .Attr("memory_region_name: string") |
| .Output("tensor: dtype") |
| .SetShapeFn(shape_inference::ExplicitShape); |
| |
| REGISTER_OP("GuaranteeConst") |
| .Input("input: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .SetShapeFn([](shape_inference::InferenceContext* c) { |
| return UnchangedShape(c); |
| }) |
| // We don't want this to be optimized away. |
| .SetIsStateful(); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("ZerosLike") |
| .Input("x: T") |
| .Output("y: T") |
| .Attr("T: type") |
| .SetShapeFn(shape_inference::UnchangedShape); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("OnesLike") |
| .Input("x: T") |
| .Output("y: T") |
| .Attr( |
| "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, int32, " |
| "int64, complex64, complex128, bool}") |
| .SetShapeFn(shape_inference::UnchangedShape); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Diag") |
| .Input("diagonal: T") |
| .Output("output: T") |
| .Attr( |
| "T: {bfloat16, half, float, double, int32, int64, complex64, " |
| "complex128}") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle in = c->input(0); |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in)); |
| // Output shape is original concatenated with itself. |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out)); |
| c->set_output(0, out); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("DiagPart") |
| .Input("input: T") |
| .Output("diagonal: T") |
| .Attr( |
| "T: {bfloat16, half, float, double, int32, int64, complex64, " |
| "complex128}") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle in = c->input(0); |
| if (!c->RankKnown(in)) { |
| c->set_output(0, c->UnknownShape()); |
| return Status::OK(); |
| } |
| // Rank must be even, and result will have rank <rank/2>. |
| const int32 rank = c->Rank(in); |
| if ((rank % 2) != 0 || rank <= 0) { |
| return errors::InvalidArgument( |
| "Input must have even and non-zero rank, input rank is ", rank); |
| } |
| const int32 mid = rank / 2; |
| |
| // output dim[i] is the merge of in.dim[i] and in.dim[i+mid]. |
| std::vector<DimensionHandle> dims(mid); |
| for (int i = 0; i < mid; ++i) { |
| TF_RETURN_IF_ERROR( |
| c->Merge(c->Dim(in, i), c->Dim(in, i + mid), &dims[i])); |
| } |
| c->set_output(0, c->MakeShape(dims)); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("MatrixDiag") |
| .Input("diagonal: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle in; |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in)); |
| if (!c->RankKnown(in)) { |
| c->set_output(0, c->UnknownShape()); |
| return Status::OK(); |
| } |
| const int32 rank = c->Rank(in); |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR( |
| c->Concatenate(in, c->Vector(c->Dim(in, rank - 1)), &out)); |
| c->set_output(0, out); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("MatrixSetDiag") |
| .Input("input: T") |
| .Input("diagonal: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input; |
| ShapeHandle diag; |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag)); |
| if (c->RankKnown(input)) { |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), c->Rank(input) - 1, &diag)); |
| } |
| DimensionHandle smallest_dim; |
| TF_RETURN_IF_ERROR( |
| c->Min(c->Dim(input, -2), c->Dim(input, -1), &smallest_dim)); |
| TF_RETURN_IF_ERROR( |
| c->Merge(smallest_dim, c->Dim(diag, -1), &smallest_dim)); |
| |
| ShapeHandle output = input; |
| if (c->RankKnown(diag) && !c->FullyDefined(input)) { |
| // Try to infer parts of shape from diag. |
| ShapeHandle diag_prefix; |
| TF_RETURN_IF_ERROR(c->Subshape(diag, 0, -1, &diag_prefix)); |
| TF_RETURN_IF_ERROR( |
| c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag)); |
| TF_RETURN_IF_ERROR(c->Merge(input, diag, &output)); |
| } |
| c->set_output(0, output); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("MatrixDiagPart") |
| .Input("input: T") |
| .Output("diagonal: T") |
| .Attr("T: type") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle in; |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &in)); |
| if (!c->RankKnown(in)) { |
| c->set_output(0, c->UnknownShape()); |
| return Status::OK(); |
| } |
| const int32 rank = c->Rank(in); |
| std::vector<DimensionHandle> dims; |
| dims.reserve(rank - 2); |
| for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i)); |
| |
| DimensionHandle min_dim; |
| TF_RETURN_IF_ERROR( |
| c->Min(c->Dim(in, rank - 2), c->Dim(in, rank - 1), &min_dim)); |
| dims.push_back(min_dim); |
| c->set_output(0, c->MakeShape(dims)); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("MatrixBandPart") |
| .Input("input: T") |
| .Input("num_lower: Tindex") |
| .Input("num_upper: Tindex") |
| .Output("band: T") |
| .Attr("T: type") |
| .Attr("Tindex: {int32, int64} = DT_INT64") |
| .SetShapeFn(shape_inference::UnchangedShape); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Reverse") |
| .Input("tensor: T") |
| .Input("dims: bool") |
| .Output("output: T") |
| .Attr( |
| "T: {uint8, int8, uint16, int16, int32, int64, bool, half, " |
| "float, double, complex64, complex128, string}") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input = c->input(0); |
| ShapeHandle dims; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims)); |
| DimensionHandle dims_dim = c->Dim(dims, 0); |
| if (c->ValueKnown(dims_dim)) { |
| TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(dims_dim), &input)); |
| } |
| if (c->Rank(input) > 8) { |
| return errors::InvalidArgument( |
| "reverse does not work on tensors with more than 8 dimensions"); |
| } |
| c->set_output(0, input); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("ReverseV2") |
| .Input("tensor: T") |
| .Input("axis: Tidx") |
| .Output("output: T") |
| .Attr("Tidx: {int32, int64} = DT_INT32") |
| .Attr( |
| "T: {uint8, int8, uint16, int16, int32, int64, bool, bfloat16, half, " |
| "float, double, complex64, complex128, string}") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input = c->input(0); |
| ShapeHandle axis; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis)); |
| if (c->Rank(input) > 8) { |
| return errors::InvalidArgument( |
| "reverse does not work on tensors with more than 8 dimensions"); |
| } |
| const Tensor* axis_tensor = c->input_tensor(1); |
| if (axis_tensor != nullptr && c->RankKnown(input)) { |
| int32 rank = c->Rank(input); |
| std::vector<int64> axis_value; |
| if (axis_tensor->dtype() == DT_INT32) { |
| axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements()); |
| } else { |
| axis_value = AsInt64<int64>(axis_tensor, axis_tensor->NumElements()); |
| } |
| std::vector<bool> axes_dense(c->Rank(input), false); |
| for (int i = 0; i < axis_value.size(); i++) { |
| int64 canonical_axis = |
| axis_value[i] < 0 ? rank + axis_value[i] : axis_value[i]; |
| if (canonical_axis < 0 || canonical_axis >= rank) { |
| return errors::InvalidArgument("'axis'[", i, "] = ", axis_value[i], |
| " is out of valid range [", 0, ", ", |
| rank - 1); |
| } |
| if (axes_dense[canonical_axis]) { |
| return errors::InvalidArgument("axis ", canonical_axis, |
| " specified more than once."); |
| } |
| axes_dense[canonical_axis] = true; |
| } |
| } |
| c->set_output(0, input); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("EditDistance") |
| .Input("hypothesis_indices: int64") |
| .Input("hypothesis_values: T") |
| .Input("hypothesis_shape: int64") |
| .Input("truth_indices: int64") |
| .Input("truth_values: T") |
| .Input("truth_shape: int64") |
| .Attr("normalize: bool = true") |
| .Attr("T: type") |
| .Output("output: float") |
| .SetShapeFn([](InferenceContext* c) { |
| TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( |
| c, c->input(0), c->input(1), c->input(2))); |
| TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( |
| c, c->input(3), c->input(4), c->input(5))); |
| const Tensor* hypothesis_shape_t = c->input_tensor(2); |
| const Tensor* truth_shape_t = c->input_tensor(5); |
| if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) { |
| // We need to know the runtime shape of the two tensors, |
| // or else the output shape is unknown. |
| return shape_inference::UnknownShape(c); |
| } |
| |
| if (hypothesis_shape_t->NumElements() != truth_shape_t->NumElements()) { |
| return errors::InvalidArgument( |
| "Num elements of hypothesis_shape does not match truth_shape: ", |
| hypothesis_shape_t->NumElements(), " vs. ", |
| truth_shape_t->NumElements()); |
| } |
| |
| auto h_values = hypothesis_shape_t->flat<int64>(); |
| auto t_values = truth_shape_t->flat<int64>(); |
| std::vector<DimensionHandle> dims(hypothesis_shape_t->NumElements() - 1); |
| for (int i = 0; i < dims.size(); ++i) { |
| dims[i] = c->MakeDim(std::max(h_values(i), t_values(i))); |
| } |
| |
| c->set_output(0, c->MakeShape(dims)); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Fill") |
| .Input("dims: index_type") |
| .Input("value: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("index_type: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| DataType index_type = DT_INT32; |
| Status s = c->GetAttr("index_type", &index_type); |
| if (!s.ok() && s.code() != error::NOT_FOUND) { |
| return s; |
| } |
| ShapeHandle unused; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
| |
| const Tensor* t = c->input_tensor(0); |
| if (t != nullptr) { |
| for (int i = 0; i < t->NumElements(); ++i) { |
| if ((index_type == DT_INT32 && t->vec<int32>()(i) < 0) || |
| (index_type == DT_INT64 && t->vec<int64>()(i) < 0)) { |
| return errors::InvalidArgument("Fill dimensions must be >= 0"); |
| } |
| } |
| } |
| |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
| c->set_output(0, out); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("_ParallelConcatStart") |
| .Output("output: dtype") |
| .Attr("shape: shape") |
| .Attr("dtype: type") |
| .SetIsStateful() |
| .SetShapeFn(shape_inference::ExplicitShape) |
| .Doc(R"doc( |
| Creates an empty Tensor with shape `shape` and type `dtype`. |
| |
| The memory can optionally be initialized. This is usually useful in |
| conjunction with inplace operations. |
| |
| shape: 1-D `Tensor` indicating the shape of the output. |
| dtype: The element type of the returned tensor. |
| output: An empty Tensor of the specified type. |
| )doc"); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("_ParallelConcatUpdate") |
| .Input("value: T") |
| .Input("update: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("loc: int") |
| .SetShapeFn(shape_inference::UnchangedShape) |
| .Doc(R"doc( |
| Updates input `value` at `loc` with `update`. |
| |
| If you use this function you will almost certainly want to add |
| a control dependency as done in the implementation of parallel_stack to |
| avoid race conditions. |
| |
| value: A `Tensor` object that will be updated in-place. |
| loc: A scalar indicating the index of the first dimension such that |
| value[loc, :] is updated. |
| update: A `Tensor` of rank one less than `value` if `loc` is a scalar, |
| otherwise of rank equal to `value` that contains the new values |
| for `value`. |
| output: `value` that has been updated accordingly. |
| )doc"); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Gather") |
| .Input("params: Tparams") |
| .Input("indices: Tindices") |
| .Attr("validate_indices: bool = true") |
| .Output("output: Tparams") |
| .Attr("Tparams: type") |
| .Attr("Tindices: {int32,int64}") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle unused; |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused)); |
| ShapeHandle params_subshape; |
| TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, ¶ms_subshape)); |
| ShapeHandle indices_shape = c->input(1); |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out)); |
| c->set_output(0, out); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("GatherV2") |
| .Input("params: Tparams") |
| .Input("indices: Tindices") |
| .Input("axis: Taxis") |
| .Output("output: Tparams") |
| .Attr("Tparams: type") |
| .Attr("Tindices: {int32,int64}") |
| .Attr("Taxis: {int32,int64}") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle params_shape; |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, ¶ms_shape)); |
| |
| ShapeHandle indices_shape = c->input(1); |
| ShapeHandle unused_axis_shape; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_axis_shape)); |
| const Tensor* axis_t = c->input_tensor(2); |
| |
| // If axis is unknown, we can only infer that the result is params_rank + |
| // indices_rank - 1. |
| if (axis_t == nullptr) { |
| if (c->RankKnown(params_shape) && c->RankKnown(indices_shape)) { |
| c->set_output(0, c->UnknownShapeOfRank(c->Rank(params_shape) + |
| c->Rank(indices_shape) - 1)); |
| } else { |
| c->set_output(0, c->UnknownShape()); |
| } |
| return Status::OK(); |
| } |
| |
| // Note, axis can be negative. |
| int64 axis = 0; |
| if (axis_t->dtype() == DT_INT32) { |
| axis = axis_t->scalar<int32>()(); |
| } else { |
| axis = axis_t->scalar<int64>()(); |
| } |
| |
| // Check that params has rank of at least axis + 1. |
| ShapeHandle unused; |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast( |
| params_shape, axis < 0 ? -axis : axis + 1, &unused)); |
| |
| ShapeHandle params_outer_subshape; |
| TF_RETURN_IF_ERROR( |
| c->Subshape(params_shape, 0, axis, ¶ms_outer_subshape)); |
| |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR( |
| c->Concatenate(params_outer_subshape, indices_shape, &out)); |
| |
| // Slice from axis + 1 to the end of params_shape to collect the inner |
| // dimensions of the result. Special case -1 here since -1 + 1 wraps, and |
| // we slice from 0 to the end of shape. Subshape() handles all other |
| // out-of-bounds checking. |
| if (axis != -1) { |
| ShapeHandle params_inner_subshape; |
| TF_RETURN_IF_ERROR( |
| c->Subshape(params_shape, axis + 1, ¶ms_inner_subshape)); |
| TF_RETURN_IF_ERROR(c->Concatenate(out, params_inner_subshape, &out)); |
| } |
| |
| c->set_output(0, out); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("GatherNd") |
| .Input("params: Tparams") |
| .Input("indices: Tindices") |
| .Output("output: Tparams") |
| .Attr("Tparams: type") |
| .Attr("Tindices: {int32,int64}") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle params = c->input(0); |
| ShapeHandle indices; |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices)); |
| DimensionHandle r_dim = c->Dim(indices, -1); |
| |
| if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) { |
| c->set_output(0, c->UnknownShape()); |
| return Status::OK(); |
| } |
| |
| if (c->Value(r_dim) > c->Rank(params)) { |
| return errors::InvalidArgument( |
| "indices.shape[-1] must be <= params.rank, but saw indices shape: ", |
| c->DebugString(indices), |
| " and params shape: ", c->DebugString(params)); |
| } |
| |
| // Remove r_dim from indices to get output. |
| ShapeHandle indices_slice; |
| ShapeHandle params_slice; |
| TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice)); |
| TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), ¶ms_slice)); |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out)); |
| c->set_output(0, out); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Identity") |
| .Input("input: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .SetShapeFn([](shape_inference::InferenceContext* c) { |
| c->set_output(0, c->input(0)); |
| auto* handle_data = c->input_handle_shapes_and_types(0); |
| if (handle_data != nullptr) { |
| c->set_output_handle_shapes_and_types(0, *handle_data); |
| } |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("Snapshot") |
| .Input("input: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .SetShapeFn([](shape_inference::InferenceContext* c) { |
| c->set_output(0, c->input(0)); |
| auto* handle_data = c->input_handle_shapes_and_types(0); |
| if (handle_data != nullptr) { |
| c->set_output_handle_shapes_and_types(0, *handle_data); |
| } |
| return Status::OK(); |
| }); |
| |
#ifdef INTEL_MKL
// _MklIdentity: registered only in INTEL_MKL builds. Output 0 takes the shape
// (and handle shapes/types, when present) of input 0. The shape function does
// not set a shape for `mkl_output`.
REGISTER_OP("_MklIdentity")
    .Input("input: T")
    .Input("mkl_input: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->input(0));
      // Handle data is optional; forward it only when present.
      if (const auto* handle_info = c->input_handle_shapes_and_types(0)) {
        c->set_output_handle_shapes_and_types(0, *handle_info);
      }
      return Status::OK();
    })
    .Doc(R"Doc( Mkl implementation of IdentityOp
)Doc");
#endif  // INTEL_MKL
| |
| REGISTER_OP("IdentityN") |
| .Input("input: T") |
| .Output("output: T") |
| .Attr("T: list(type)") |
| .SetShapeFn([](shape_inference::InferenceContext* c) { |
| std::vector<ShapeHandle> input; |
| TF_RETURN_IF_ERROR(c->input("input", &input)); |
| TF_RETURN_IF_ERROR(c->set_output("output", input)); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("RefIdentity") |
| .Input("input: Ref(T)") |
| .Output("output: Ref(T)") |
| .Attr("T: type") |
| .SetShapeFn(shape_inference::UnchangedShape) |
| .SetAllowsUninitializedInput(); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("DebugGradientIdentity") |
| .Input("input: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .SetShapeFn(shape_inference::UnchangedShape) |
| .SetAllowsUninitializedInput(); |
| |
| REGISTER_OP("DebugGradientRefIdentity") |
| .Input("input: Ref(T)") |
| .Output("output: Ref(T)") |
| .Attr("T: type") |
| .SetShapeFn(shape_inference::UnchangedShape) |
| .SetAllowsUninitializedInput(); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("StopGradient") |
| .Input("input: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .SetShapeFn(shape_inference::UnchangedShape); |
| |
| REGISTER_OP("PreventGradient") |
| .Input("input: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("message: string = ''") |
| .SetShapeFn(shape_inference::UnchangedShape); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("CheckNumerics") |
| .Input("tensor: T") |
| .Output("output: T") |
| .Attr("T: {bfloat16, half, float, double}") |
| .Attr("message: string") |
| .SetShapeFn(shape_inference::UnchangedShape); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Reshape") |
| .Input("tensor: T") |
| .Input("shape: Tshape") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Tshape: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| return SetOutputShapeForReshape(c); |
| }); |
| |
#ifdef INTEL_MKL
// _MklReshape: registered only in INTEL_MKL builds. Carries two extra uint8
// inputs and one extra uint8 output; the shape function is the same
// SetOutputShapeForReshape helper that Reshape uses.
REGISTER_OP("_MklReshape")
    .Input("tensor: T")
    .Input("shape: Tshape")
    .Input("mkl_tensor: uint8")
    .Input("mkl_shape: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: type")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* ctx) {
      return SetOutputShapeForReshape(ctx);
    })
    .Doc(R"Doc( MKL implementation of ReshapeOp.
)Doc");
#endif  // INTEL_MKL
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("InvertPermutation") |
| .Input("x: T") |
| .Output("y: T") |
| .Attr("T: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle x; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x)); |
| c->set_output(0, x); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Transpose") |
| .Input("x: T") |
| .Input("perm: Tperm") |
| .Output("y: T") |
| .Attr("T: type") |
| .Attr("Tperm: {int32, int64} = DT_INT32") |
| .SetShapeFn(TransposeShapeFn); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("ConjugateTranspose") |
| .Input("x: T") |
| .Input("perm: Tperm") |
| .Output("y: T") |
| .Attr("T: type") |
| .Attr("Tperm: {int32, int64} = DT_INT32") |
| .SetShapeFn(TransposeShapeFn); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Unique") |
| .Input("x: T") |
| .Output("y: T") |
| .Output("idx: out_idx") |
| .Attr("T: type") |
| .Attr("out_idx: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); |
| c->set_output(1, c->input(0)); |
| // Assert that the input rank is 1. |
| ShapeHandle dummy; |
| return c->WithRank(c->input(0), 1, &dummy); |
| }); |
| |
| REGISTER_OP("UniqueV2") |
| .Input("x: T") |
| .Input("axis: Taxis") |
| .Output("y: T") |
| .Output("idx: out_idx") |
| .Attr("T: type") |
| .Attr("Taxis: {int32,int64} = DT_INT64") |
| .Attr("out_idx: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); |
| c->set_output(1, c->input(0)); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("UniqueWithCounts") |
| .Input("x: T") |
| .Output("y: T") |
| .Output("idx: out_idx") |
| .Output("count: out_idx") |
| .Attr("T: type") |
| .Attr("out_idx: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| auto uniq = c->Vector(InferenceContext::kUnknownDim); |
| c->set_output(0, uniq); |
| c->set_output(1, c->input(0)); |
| c->set_output(2, uniq); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("UniqueWithCountsV2") |
| .Input("x: T") |
| .Input("axis: Taxis") |
| .Output("y: T") |
| .Output("idx: out_idx") |
| .Output("count: out_idx") |
| .Attr("T: type") |
| .Attr("Taxis: {int32,int64} = DT_INT64") |
| .Attr("out_idx: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| auto uniq = c->Vector(InferenceContext::kUnknownDim); |
| c->set_output(0, uniq); |
| c->set_output(1, c->input(0)); |
| c->set_output(2, uniq); |
| return Status::OK(); |
| }); |
| |
| namespace { |
| |
| Status ShapeShapeFn(InferenceContext* c) { |
| for (int i = 0; i < c->num_inputs(); ++i) { |
| DimensionHandle dim; |
| if (c->RankKnown(c->input(i))) { |
| dim = c->MakeDim(c->Rank(c->input(i))); |
| } else { |
| dim = c->UnknownDim(); |
| } |
| c->set_output(i, c->Vector(dim)); |
| } |
| return Status::OK(); |
| } |
| |
| } // namespace |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Shape") |
| .Input("input: T") |
| .Output("output: out_type") |
| .Attr("T: type") |
| .Attr("out_type: {int32, int64} = DT_INT32") |
| .SetShapeFn(ShapeShapeFn); |
| |
| REGISTER_OP("ShapeN") |
| .Input("input: N * T") |
| .Output("output: N * out_type") |
| .Attr("N: int") |
| .Attr("T: type") |
| .Attr("out_type: {int32, int64} = DT_INT32") |
| .SetShapeFn(ShapeShapeFn); |
| |
| REGISTER_OP("EnsureShape") |
| .Input("input: T") |
| .Output("output: T") |
| .Attr("shape: shape") |
| .Attr("T: type") |
| .SetShapeFn([](InferenceContext* c) { |
| // Merges desired shape and statically known shape of input |
| PartialTensorShape desired_shape; |
| TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape)); |
| |
| int rank = desired_shape.dims(); |
| ShapeHandle input_shape_handle; |
| ShapeHandle desired_shape_handle; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle)); |
| TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( |
| desired_shape, &desired_shape_handle)); |
| |
| ShapeHandle merged_shape; |
| TF_RETURN_IF_ERROR( |
| c->Merge(desired_shape_handle, input_shape_handle, &merged_shape)); |
| c->set_output(0, merged_shape); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("ReverseSequence") |
| .Input("input: T") |
| .Input("seq_lengths: Tlen") |
| .Output("output: T") |
| .Attr("seq_dim: int") |
| .Attr("batch_dim: int = 0") |
| .Attr("T: type") |
| .Attr("Tlen: {int32, int64} = DT_INT64") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input = c->input(0); |
| ShapeHandle seq_lens_shape; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape)); |
| |
| int64 seq_dim; |
| TF_RETURN_IF_ERROR(c->GetAttr("seq_dim", &seq_dim)); |
| int64 batch_dim; |
| TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim)); |
| |
| if (!c->RankKnown(input)) { |
| return shape_inference::UnknownShape(c); |
| } |
| |
| // Validate batch_dim and seq_dim against input. |
| const int32 input_rank = c->Rank(input); |
| if (batch_dim >= input_rank) { |
| return errors::InvalidArgument( |
| "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank); |
| } |
| if (seq_dim >= input_rank) { |
| return errors::InvalidArgument( |
| "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank); |
| } |
| |
| DimensionHandle batch_dim_dim = c->Dim(input, batch_dim); |
| TF_RETURN_IF_ERROR( |
| c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim)); |
| |
| // Replace batch_dim of input with batch_size |
| ShapeHandle output_shape; |
| TF_RETURN_IF_ERROR( |
| c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape)); |
| c->set_output(0, output_shape); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Rank") |
| .Input("input: T") |
| .Output("output: int32") |
| .Attr("T: type") |
| .SetShapeFn(shape_inference::ScalarShape); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Size") |
| .Input("input: T") |
| .Output("output: out_type") |
| .Attr("T: type") |
| .Attr("out_type: {int32, int64} = DT_INT32") |
| .SetShapeFn(shape_inference::ScalarShape); |
| |
| namespace { |
| |
| // This SliceHelper processes the output shape of the `slice` |
| // when the tensor of `sizes` is available. |
| template <typename T> |
| Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, |
| const Tensor* sizes_value, |
| std::vector<DimensionHandle>* dims) { |
| auto sizes_vec = sizes_value->vec<T>(); |
| for (int i = 0; i < sizes_value->NumElements(); ++i) { |
| DimensionHandle dim = c->Dim(c->input(0), i); |
| if (sizes_vec(i) != -1) { |
| auto dim_val = c->Value(dim); |
| if (sizes_vec(i) < 0) { |
| return errors::InvalidArgument( |
| "Out of bounds slicing on dimension ", i, " of length ", dim_val, |
| ": sizes vector cannot be < -1, but was ", sizes_vec(i)); |
| } |
| |
| dims->emplace_back(c->MakeDim(sizes_vec(i))); |
| } else { |
| DimensionHandle result; |
| TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result)); |
| dims->emplace_back(result); |
| } |
| } |
| |
| return Status::OK(); |
| } |
| } // namespace |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Slice") |
| .Input("input: T") |
| .Input("begin: Index") |
| .Input("size: Index") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Index: {int32,int64}") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input = c->input(0); |
| ShapeHandle begin_shape; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); |
| ShapeHandle sizes_shape; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape)); |
| |
| // Merge to check compatibility of begin and sizes tensors. |
| TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape)); |
| |
| DimensionHandle ndims = c->Dim(begin_shape, 0); |
| if (c->ValueKnown(ndims)) { |
| TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input)); |
| } |
| |
| // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known |
| // values, even though the `begin` value does not represent a shape. |
| ShapeHandle begin_value; |
| TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value)); |
| |
| // We check the tensor value here and will only use |
| // `MakeShapeFromShapeTensor` when `sizes_value` is null. |
| // The reason is that `sizes`might contain -1, which can't |
| // be represented (-1 in the ShapeHandle would mean "unknown". |
| const Tensor* sizes_value = c->input_tensor(2); |
| |
| if (sizes_value != nullptr) { |
| TF_RETURN_IF_ERROR( |
| c->WithRank(begin_value, sizes_value->NumElements(), &begin_value)); |
| std::vector<DimensionHandle> dims; |
| // If the begin and sizes tensors are available, then |
| // we can be precise about the shape of the output. |
| if (sizes_value->dtype() == DT_INT64) { |
| TF_RETURN_IF_ERROR( |
| SliceHelper<int64>(c, begin_value, sizes_value, &dims)); |
| } else { |
| TF_RETURN_IF_ERROR( |
| SliceHelper<int32>(c, begin_value, sizes_value, &dims)); |
| } |
| |
| c->set_output(0, c->MakeShape(dims)); |
| return Status::OK(); |
| } else { |
| // In case `sizes` is not available (`sizes_value` is null), |
| // we could try to use `MakeShapeFromShapeTensor` here. |
| // If sizes contain -1, we will simply consider it as `Unknown`. |
| // This is less than ideal but still an improvement of shape inference. |
| // The following is an example that returns [None, 1, None] with this |
| // code path: |
| // z = tf.zeros((1, 2, 3)) |
| // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1]) |
| // m.get_shape().as_list() |
| ShapeHandle sizes_value; |
| TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value)); |
| if (c->RankKnown(sizes_value)) { |
| TF_RETURN_IF_ERROR( |
| c->WithRank(begin_value, c->Rank(sizes_value), &begin_value)); |
| std::vector<DimensionHandle> dims; |
| dims.reserve(c->Rank(sizes_value)); |
| for (int i = 0; i < c->Rank(sizes_value); ++i) { |
| dims.emplace_back(c->Dim(sizes_value, i)); |
| } |
| c->set_output(0, c->MakeShape(dims)); |
| return Status::OK(); |
| } |
| |
| // We might know the rank of the input. |
| if (c->RankKnown(input)) { |
| c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); |
| return Status::OK(); |
| } else { |
| return shape_inference::UnknownShape(c); |
| } |
| } |
| |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("StridedSlice") |
| .Input("input: T") |
| .Input("begin: Index") |
| .Input("end: Index") |
| .Input("strides: Index") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Index: {int32, int64}") |
| .Attr("begin_mask: int = 0") |
| .Attr("end_mask: int = 0") |
| .Attr("ellipsis_mask: int = 0") |
| .Attr("new_axis_mask: int = 0") |
| .Attr("shrink_axis_mask: int = 0") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input = c->input(0); |
| ShapeHandle begin_shape, end_shape, strides_shape; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &end_shape)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &strides_shape)); |
| TF_RETURN_IF_ERROR(c->Merge(begin_shape, end_shape, &begin_shape)); |
| TF_RETURN_IF_ERROR(c->Merge(begin_shape, strides_shape, &begin_shape)); |
| DimensionHandle sparse_dims_dim = c->Dim(begin_shape, 0); |
| |
| const Tensor* strides_value = c->input_tensor(3); |
| // TODO(aselle,allenl): If we had a stride_mask it would be possible to do |
| // more shape inference here (e.g. for x[3, ::T]). |
| if (!c->RankKnown(input) || !c->ValueKnown(sparse_dims_dim) || |
| strides_value == nullptr) { |
| c->set_output(0, c->UnknownShape()); |
| return Status::OK(); |
| } |
| |
| PartialTensorShape input_shape({}); |
| for (int i = 0; i < c->Rank(input); ++i) { |
| auto dim = c->Dim(input, i); |
| input_shape.AddDim(c->ValueKnown(dim) ? c->Value(dim) : -1); |
| } |
| |
| int32 begin_mask, end_mask, ellipsis_mask, new_axis_mask, |
| shrink_axis_mask; |
| TF_RETURN_IF_ERROR(c->GetAttr("begin_mask", &begin_mask)); |
| TF_RETURN_IF_ERROR(c->GetAttr("end_mask", &end_mask)); |
| TF_RETURN_IF_ERROR(c->GetAttr("ellipsis_mask", &ellipsis_mask)); |
| TF_RETURN_IF_ERROR(c->GetAttr("new_axis_mask", &new_axis_mask)); |
| TF_RETURN_IF_ERROR(c->GetAttr("shrink_axis_mask", &shrink_axis_mask)); |
| |
| const Tensor* begin_value = c->input_tensor(1); |
| const Tensor* end_value = c->input_tensor(2); |
| |
| PartialTensorShape processing_shape, final_shape; |
| bool is_identity, is_simple_slice, slice_dim0; |
| gtl::InlinedVector<int64, 4> begin, end, strides; |
| TF_RETURN_IF_ERROR(ValidateStridedSliceOp( |
| begin_value, end_value, *strides_value, input_shape, begin_mask, |
| end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, |
| &processing_shape, &final_shape, &is_identity, &is_simple_slice, |
| &slice_dim0, &begin, &end, &strides)); |
| |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(final_shape, &out)); |
| c->set_output(0, out); |
| |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("StridedSliceGrad") |
| .Input("shape: Index") |
| .Input("begin: Index") |
| .Input("end: Index") |
| .Input("strides: Index") |
| .Input("dy: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Index: {int32, int64}") |
| .Attr("begin_mask: int = 0") |
| .Attr("end_mask: int = 0") |
| .Attr("ellipsis_mask: int = 0") |
| .Attr("new_axis_mask: int = 0") |
| .Attr("shrink_axis_mask: int = 0") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
| c->set_output(0, out); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("StridedSliceAssign") |
| .Input("ref: Ref(T)") |
| .Input("begin: Index") |
| .Input("end: Index") |
| .Input("strides: Index") |
| .Input("value: T") |
| .Output("output_ref: Ref(T)") |
| .Attr("T: type") |
| .Attr("Index: {int32, int64}") |
| .Attr("begin_mask: int = 0") |
| .Attr("end_mask: int = 0") |
| .Attr("ellipsis_mask: int = 0") |
| .Attr("new_axis_mask: int = 0") |
| .Attr("shrink_axis_mask: int = 0") |
| .SetShapeFn(shape_inference::UnchangedShape); |
| // TODO(aselle): Fix this documentation once StridedSliceAssign Supports |
| // broadcasting. |
| // -------------------------------------------------------------------------- |
| |
| REGISTER_OP("ResourceStridedSliceAssign") |
| .Input("ref: resource") |
| .Input("begin: Index") |
| .Input("end: Index") |
| .Input("strides: Index") |
| .Input("value: T") |
| .Attr("T: type") |
| .Attr("Index: {int32, int64}") |
| .Attr("begin_mask: int = 0") |
| .Attr("end_mask: int = 0") |
| .Attr("ellipsis_mask: int = 0") |
| .Attr("new_axis_mask: int = 0") |
| .Attr("shrink_axis_mask: int = 0") |
| .SetShapeFn(shape_inference::NoOutputs); |
| |
| REGISTER_OP("Tile") |
| .Input("input: T") |
| .Input("multiples: Tmultiples") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Tmultiples: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input = c->input(0); |
| // NOTE(mrry): Represent `multiples` as a `TensorShape` because (i) |
| // it is a vector of non-negative integers, and (ii) doing so allows |
| // us to handle partially-known multiples. |
| ShapeHandle multiples; |
| TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &multiples)); |
| if (c->RankKnown(input)) { |
| TF_RETURN_IF_ERROR(c->WithRank(multiples, c->Rank(input), &multiples)); |
| ShapeHandle dummy; |
| TF_RETURN_IF_ERROR( |
| c->Merge(c->input(1), c->Vector(c->Rank(input)), &dummy)); |
| } |
| |
| if (!c->RankKnown(multiples)) { |
| return shape_inference::UnknownShape(c); |
| } |
| |
| int32 rank = c->Rank(multiples); |
| TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input)); |
| std::vector<DimensionHandle> dims(rank); |
| for (int i = 0; i < rank; ++i) { |
| TF_RETURN_IF_ERROR( |
| c->Multiply(c->Dim(input, i), c->Dim(multiples, i), &dims[i])); |
| } |
| c->set_output(0, c->MakeShape(dims)); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("TileGrad") |
| .Input("input: T") |
| .Input("multiples: int32") |
| .Output("output: T") |
| .Attr("T: type") |
| .Deprecated(3, "TileGrad has been replaced with reduce_sum") |
| .SetShapeFn(tensorflow::shape_inference::UnknownShape); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Where") |
| .Input("input: T") |
| .Attr("T: {numbertype, bool} = DT_BOOL") |
| .Output("index: int64") |
| .SetShapeFn([](InferenceContext* c) { |
| c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0)))); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("BroadcastArgs") |
| .Input("s0: T") |
| .Input("s1: T") |
| .Output("r0: T") |
| .Attr("T: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle unused; |
| ShapeHandle shape_x = c->input(0); |
| ShapeHandle shape_y = c->input(1); |
| TF_RETURN_IF_ERROR(c->WithRank(shape_x, 1, &unused)); |
| TF_RETURN_IF_ERROR(c->WithRank(shape_y, 1, &unused)); |
| |
| if (!c->ValueKnown(c->Dim(shape_x, 0)) || |
| !c->ValueKnown(c->Dim(shape_y, 0))) { |
| c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); |
| return Status::OK(); |
| } |
| |
| int64 x_dim = c->Value(c->Dim(shape_x, 0)); |
| int64 y_dim = c->Value(c->Dim(shape_y, 0)); |
| |
| // Broadcasted shape is going to be as large as the largest dimension. |
| c->set_output(0, c->Vector(std::max(x_dim, y_dim))); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("BroadcastGradientArgs") |
| .Input("s0: T") |
| .Input("s1: T") |
| .Output("r0: T") |
| .Output("r1: T") |
| .Attr("T: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| // TODO(mrry): Implement constant_value for BroadcastGradientArgs? |
| ShapeHandle unused; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
| c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); |
| c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Pad") |
| .Input("input: T") |
| .Input("paddings: Tpaddings") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Tpaddings: {int32, int64} = DT_INT32") |
| .SetShapeFn(PadShapeFn); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("PadV2") |
| .Input("input: T") |
| .Input("paddings: Tpaddings") |
| .Input("constant_values: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Tpaddings: {int32, int64} = DT_INT32") |
| .SetShapeFn(PadShapeFn); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("MirrorPad") |
| .Input("input: T") |
| .Input("paddings: Tpaddings") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Tpaddings: {int32, int64} = DT_INT32") |
| .Attr(GetMirrorPadModeAttrString()) |
| .SetShapeFn(PadShapeFn); |
| |
| // -------------------------------------------------------------------------- |
| namespace { |
| template <typename T> |
| Status MirrorPadKnown(InferenceContext* c, ShapeHandle input, |
| const Tensor* paddings_t, int64 input_rank) { |
| auto paddings_data = paddings_t->matrix<T>(); |
| std::vector<DimensionHandle> dims(input_rank); |
| for (int64 i = 0; i < input_rank; ++i) { |
| const int64 pad0 = static_cast<int64>(paddings_data(i, 0)); |
| const int64 pad1 = static_cast<int64>(paddings_data(i, 1)); |
| if (pad0 < 0 || pad1 < 0) { |
| return errors::InvalidArgument("Paddings must be non-negative"); |
| } |
| |
| TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i])); |
| } |
| c->set_output(0, c->MakeShape(dims)); |
| return Status::OK(); |
| } |
| |
| } // namespace |
| |
| REGISTER_OP("MirrorPadGrad") |
| .Input("input: T") |
| .Input("paddings: Tpaddings") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Tpaddings: {int32, int64} = DT_INT32") |
| .Attr(GetMirrorPadModeAttrString()) |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle paddings; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings)); |
| DimensionHandle pad_0 = c->Dim(paddings, 0); |
| if (!c->ValueKnown(pad_0)) { |
| // We don't know the rank of the output since the first |
| // padding dimension is unknown. |
| c->set_output(0, c->UnknownShape()); |
| return Status::OK(); |
| } |
| |
| int64 input_rank = c->Value(pad_0); |
| ShapeHandle input; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), input_rank, &input)); |
| TF_RETURN_IF_ERROR( |
| c->Merge(paddings, c->Matrix(input_rank, 2), &paddings)); |
| |
| const Tensor* paddings_t = c->input_tensor(1); |
| if (paddings_t == nullptr) { |
| // Values of 'paddings' is not available, but we know the |
| // input rank, so return the rank of the output with unknown |
| // dimensions. |
| c->set_output(0, c->UnknownShapeOfRank(input_rank)); |
| return Status::OK(); |
| } |
| |
| if (paddings_t->dtype() == DT_INT32) { |
| return MirrorPadKnown<int32>(c, input, paddings_t, input_rank); |
| } else { |
| return MirrorPadKnown<int64>(c, input, paddings_t, input_rank); |
| } |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Placeholder") |
| .Output("output: dtype") |
| .Attr("dtype: type") |
| .Attr("shape: shape = { unknown_rank: true }") |
| .SetShapeFn([](InferenceContext* c) { |
| PartialTensorShape shape; |
| TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); |
| |
| // Placeholder has legacy behavior where we cannot tell the difference |
| // between a scalar shape attribute and 'unknown shape'. So if the shape |
| // is a scalar, we return an unknown shape. |
| if (c->graph_def_version() <= 21 && shape.dims() <= 0) { |
| return shape_inference::UnknownShape(c); |
| } |
| |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); |
| c->set_output(0, out); |
| return Status::OK(); |
| }); |
| |
| // Placeholder was modified in a backwards compatible way to do what |
| // PlaceholderV2 did, so we have deprecated V2 (no one was really |
| // using it). |
| REGISTER_OP("PlaceholderV2") |
| .Output("output: dtype") |
| .Attr("dtype: type") |
| .Attr("shape: shape") |
| .SetShapeFn(shape_inference::ExplicitShape) |
| .Deprecated(23, "Placeholder now behaves the same as PlaceholderV2."); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("PlaceholderWithDefault") |
| .Input("input: dtype") |
| .Output("output: dtype") |
| .Attr("dtype: type") |
| .Attr("shape: shape") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input = c->input(0); |
| PartialTensorShape shape; |
| TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); |
| |
| // We merge for compatibility checking, but return the output, |
| // since output_shape may be less precise than input_shape. |
| ShapeHandle unused; |
| TF_RETURN_IF_ERROR(c->Merge(input, out, &unused)); |
| c->set_output(0, out); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("ExpandDims") |
| .Input("input: T") |
| .Input("dim: Tdim") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Tdim: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input = c->input(0); |
| |
| const Tensor* dim_t = c->input_tensor(1); |
| if (dim_t != nullptr && dim_t->NumElements() != 1) { |
| return errors::InvalidArgument( |
| "'dim' input must be a tensor with a single value"); |
| } |
| if (dim_t == nullptr || !c->RankKnown(input)) { |
| c->set_output(0, c->UnknownShape()); |
| return Status::OK(); |
| } |
| |
| int64 dim; |
| if (dim_t->dtype() == DT_INT32) { |
| dim = static_cast<int64>(dim_t->flat<int32>()(0)); |
| } else { |
| dim = dim_t->flat<int64>()(0); |
| } |
| |
| const int32 rank = c->Rank(input); |
| const int32 min_dim = -1 * rank - 1; |
| if (dim < min_dim || dim > rank) { |
| return errors::InvalidArgument("dim ", dim, " not in the interval [", |
| min_dim, ", ", rank, "]."); |
| } |
| |
| if (dim < 0) { |
| dim += rank + 1; |
| } |
| |
| ShapeHandle end; |
| TF_RETURN_IF_ERROR(c->Subshape(input, dim, &end)); |
| |
| // Build output as start + 1 + end. |
| ShapeHandle output; |
| TF_RETURN_IF_ERROR(c->Subshape(input, 0, dim, &output)); |
| TF_RETURN_IF_ERROR(c->Concatenate(output, c->Vector(1), &output)); |
| TF_RETURN_IF_ERROR(c->Concatenate(output, end, &output)); |
| c->set_output(0, output); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("Squeeze") |
| .Input("input: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("squeeze_dims: list(int) >= 0 = []") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input = c->input(0); |
| if (!c->RankKnown(input)) { |
| // Input shape unknown. |
| return shape_inference::UnknownShape(c); |
| } |
| |
| const int32 input_rank = c->Rank(input); |
| |
| // Validate and wrap squeeze dimensions. |
| std::vector<int32> squeeze_dims; |
| TF_RETURN_IF_ERROR(c->GetAttr("squeeze_dims", &squeeze_dims)); |
| for (int i = 0; i < squeeze_dims.size(); ++i) { |
| if (squeeze_dims[i] < -input_rank || squeeze_dims[i] >= input_rank) { |
| return errors::InvalidArgument("squeeze_dims[", i, "] not in [", |
| -input_rank, ",", input_rank, ")."); |
| } |
| |
| if (squeeze_dims[i] < 0) { |
| squeeze_dims[i] += input_rank; |
| } |
| } |
| |
| std::vector<DimensionHandle> result_shape; |
| for (int i = 0; i < input_rank; ++i) { |
| // True if squeeze_dims contains an entry to squeeze this |
| // dimension. |
| bool is_explicit_match = |
| std::find(squeeze_dims.begin(), squeeze_dims.end(), i) != |
| squeeze_dims.end(); |
| |
| DimensionHandle dim = c->Dim(input, i); |
| |
| if (!c->ValueKnown(dim)) { |
| // Assume that the squeezed dimension will be 1 at runtime. |
| if (is_explicit_match) continue; |
| |
| // If squeezing all 1 dimensions, and we see an unknown value, |
| // give up and return Unknown Shape. |
| if (squeeze_dims.empty()) { |
| c->set_output(0, c->UnknownShape()); |
| return Status::OK(); |
| } |
| } else if (c->Value(dim) == 1) { |
| if (is_explicit_match || squeeze_dims.empty()) { |
| // If explicitly squeezing, or squeezing all 1s, remove |
| // this dimension. |
| continue; |
| } |
| } else if (is_explicit_match) { |
| return errors::InvalidArgument("Can not squeeze dim[", i, |
| "], expected a dimension of 1, got ", |
| c->Value(c->Dim(input, i))); |
| } |
| |
| result_shape.emplace_back(dim); |
| } |
| |
| c->set_output(0, c->MakeShape(result_shape)); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("ListDiff") |
| .Input("x: T") |
| .Input("y: T") |
| .Output("out: T") |
| .Output("idx: out_idx") |
| .Attr("T: type") |
| .Attr("out_idx: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle unused; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
| // TODO(mrry): Indicate that the length falls within an interval? |
| ShapeHandle out = c->Vector(InferenceContext::kUnknownDim); |
| c->set_output(0, out); |
| c->set_output(1, out); |
| return Status::OK(); |
| }); |
| |
| namespace { |
| |
| // Converts Tensor to flat std::vector<int64>. |
| template <typename InputType> |
| std::vector<int64> GetFlatInt64(const Tensor& t) { |
| std::vector<int64> output(t.shape().num_elements()); |
| auto eigen_vec = t.flat<InputType>(); |
| std::copy_n(&eigen_vec(0), output.size(), output.begin()); |
| return output; |
| } |
| |
| // Converts int32 or int64 Tensor to flat std::vector<int64>. |
| std::vector<int64> GetFlatInt64(const Tensor& t) { |
| if (t.dtype() == DT_INT32) { |
| return GetFlatInt64<int32>(t); |
| } else { |
| return GetFlatInt64<int64>(t); |
| } |
| } |
| |
| Status SpaceToBatchShapeHelper(InferenceContext* c, ShapeHandle input_shape, |
| ShapeHandle block_shape_shape, |
| const Tensor* block_shape_t, |
| ShapeHandle paddings_shape, |
| const Tensor* paddings_t) { |
| if (c->Rank(block_shape_shape) != 1) { |
| return errors::InvalidArgument("block_shape must have rank 1."); |
| } |
| |
| const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0); |
| if (!c->ValueKnown(num_block_dims_handle)) { |
| return errors::InvalidArgument("block_shape must have known size."); |
| } |
| |
| const int64 num_block_dims = c->Value(num_block_dims_handle); |
| |
| TF_RETURN_IF_ERROR( |
| c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape)); |
| |
| TF_RETURN_IF_ERROR( |
| c->Merge(paddings_shape, c->Matrix(num_block_dims, 2), &paddings_shape)); |
| |
| DimensionHandle batch_size = c->Dim(input_shape, 0); |
| std::vector<int64> block_shape_vec; |
| if (block_shape_t) { |
| block_shape_vec = GetFlatInt64(*block_shape_t); |
| for (int64 dim = 0; dim < num_block_dims; ++dim) { |
| const int64 block_shape_value = block_shape_vec[dim]; |
| if (block_shape_value < 1) { |
| return errors::InvalidArgument("block_shape must be positive"); |
| } |
| if (c->ValueKnown(batch_size)) { |
| TF_RETURN_IF_ERROR( |
| c->Multiply(batch_size, block_shape_value, &batch_size)); |
| } else { |
| batch_size = c->UnknownDim(); |
| } |
| } |
| } else if (num_block_dims > 0) { |
| batch_size = c->UnknownDim(); |
| } |
| |
| std::vector<DimensionHandle> output_dims{batch_size}; |
| output_dims.resize(num_block_dims + 1, c->UnknownDim()); |
| |
| if (paddings_t) { |
| const std::vector<int64> paddings_vec = GetFlatInt64(*paddings_t); |
| for (int64 dim = 0; dim < num_block_dims; ++dim) { |
| const int64 pad_start = paddings_vec[dim * 2], |
| pad_end = paddings_vec[dim * 2 + 1]; |
| if (pad_start < 0 || pad_end < 0) { |
| return errors::InvalidArgument("paddings cannot be negative"); |
| } |
| if (block_shape_t) { |
| DimensionHandle padded_size; |
| TF_RETURN_IF_ERROR( |
| c->Add(c->Dim(input_shape, dim + 1), pad_start, &padded_size)); |
| TF_RETURN_IF_ERROR(c->Add(padded_size, pad_end, &padded_size)); |
| TF_RETURN_IF_ERROR(c->Divide(padded_size, block_shape_vec[dim], |
| /*evenly_divisible=*/true, |
| &output_dims[dim + 1])); |
| } |
| } |
| } |
| |
| ShapeHandle remaining_input_shape; |
| TF_RETURN_IF_ERROR( |
| c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape)); |
| |
| ShapeHandle result; |
| TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims), |
| remaining_input_shape, &result)); |
| c->set_output(0, result); |
| return Status::OK(); |
| } |
| |
| Status BatchToSpaceShapeHelper(InferenceContext* c, ShapeHandle input_shape, |
| ShapeHandle block_shape_shape, |
| const Tensor* block_shape_t, |
| ShapeHandle crops_shape, const Tensor* crops_t) { |
| if (c->Rank(block_shape_shape) != 1) { |
| return errors::InvalidArgument("block_shape must have rank 1."); |
| } |
| |
| const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0); |
| if (!c->ValueKnown(num_block_dims_handle)) { |
| return errors::InvalidArgument("block_shape must have known size."); |
| } |
| |
| const int64 num_block_dims = c->Value(num_block_dims_handle); |
| |
| TF_RETURN_IF_ERROR( |
| c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape)); |
| |
| TF_RETURN_IF_ERROR( |
| c->Merge(crops_shape, c->Matrix(num_block_dims, 2), &crops_shape)); |
| |
| DimensionHandle batch_size = c->Dim(input_shape, 0); |
| std::vector<int64> block_shape_vec; |
| if (block_shape_t) { |
| block_shape_vec = GetFlatInt64(*block_shape_t); |
| for (int64 dim = 0; dim < num_block_dims; ++dim) { |
| const int64 block_shape_value = block_shape_vec[dim]; |
| if (block_shape_value < 1) { |
| return errors::InvalidArgument("block_shape must be positive"); |
| } |
| if (c->ValueKnown(batch_size)) { |
| TF_RETURN_IF_ERROR(c->Divide(batch_size, block_shape_value, |
| /*evenly_divisible=*/true, &batch_size)); |
| } else { |
| batch_size = c->UnknownDim(); |
| } |
| } |
| } else if (num_block_dims > 0) { |
| batch_size = c->UnknownDim(); |
| } |
| |
| std::vector<DimensionHandle> output_dims{batch_size}; |
| output_dims.resize(num_block_dims + 1, c->UnknownDim()); |
| |
| if (crops_t) { |
| const std::vector<int64> crops_vec = GetFlatInt64(*crops_t); |
| for (int64 dim = 0; dim < num_block_dims; ++dim) { |
| const int64 crop_start = crops_vec[dim * 2], |
| crop_end = crops_vec[dim * 2 + 1]; |
| if (crop_start < 0 || crop_end < 0) { |
| return errors::InvalidArgument("crops cannot be negative"); |
| } |
| if (block_shape_t) { |
| DimensionHandle cropped_size; |
| TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, dim + 1), |
| block_shape_vec[dim], &cropped_size)); |
| TF_RETURN_IF_ERROR( |
| c->Subtract(cropped_size, crop_start, &cropped_size)); |
| TF_RETURN_IF_ERROR( |
| c->Subtract(cropped_size, crop_end, &output_dims[dim + 1])); |
| } |
| } |
| } |
| |
| ShapeHandle remaining_input_shape; |
| TF_RETURN_IF_ERROR( |
| c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape)); |
| |
| ShapeHandle result; |
| TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims), |
| remaining_input_shape, &result)); |
| c->set_output(0, result); |
| return Status::OK(); |
| } |
| |
| } // namespace |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("SpaceToBatchND") |
| .Input("input: T") |
| .Input("block_shape: Tblock_shape") |
| .Input("paddings: Tpaddings") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Tblock_shape: {int32, int64} = DT_INT32") |
| .Attr("Tpaddings: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| return SpaceToBatchShapeHelper(c, c->input(0), c->input(1), |
| c->input_tensor(1), c->input(2), |
| c->input_tensor(2)); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("SpaceToBatch") |
| .Input("input: T") |
| .Input("paddings: Tpaddings") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Tpaddings: {int32, int64} = DT_INT32") |
| .Attr("block_size: int >= 2") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input_shape; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); |
| |
| int32 block_size; |
| TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size)); |
| |
| Tensor block_shape(tensorflow::DT_INT64, TensorShape({2})); |
| auto block_shape_vec = block_shape.vec<int64>(); |
| block_shape_vec(0) = block_size; |
| block_shape_vec(1) = block_size; |
| |
| return SpaceToBatchShapeHelper(c, input_shape, c->MakeShape({2}), |
| &block_shape, c->input(1), |
| c->input_tensor(1)); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("BatchToSpaceND") |
| .Input("input: T") |
| .Input("block_shape: Tblock_shape") |
| .Input("crops: Tcrops") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Tblock_shape: {int32, int64} = DT_INT32") |
| .Attr("Tcrops: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| return BatchToSpaceShapeHelper(c, c->input(0), c->input(1), |
| c->input_tensor(1), c->input(2), |
| c->input_tensor(2)); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("BatchToSpace") |
| .Input("input: T") |
| .Input("crops: Tidx") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("block_size: int >= 2") |
| .Attr("Tidx: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input_shape; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); |
| |
| int32 block_size; |
| TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size)); |
| |
| Tensor block_shape(tensorflow::DT_INT64, TensorShape({2})); |
| auto block_shape_vec = block_shape.vec<int64>(); |
| block_shape_vec(0) = block_size; |
| block_shape_vec(1) = block_size; |
| |
| return BatchToSpaceShapeHelper(c, input_shape, c->MakeShape({2}), |
| &block_shape, c->input(1), |
| c->input_tensor(1)); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("SpaceToDepth") |
| .Input("input: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("block_size: int >= 2") |
| .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") |
| // TODO(pauldonnelly): Implement GPU kernels for NCHW_VECT_C. |
| .SetShapeFn([](InferenceContext* c) { |
| string data_format_str; |
| TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); |
| TensorFormat data_format; |
| FormatFromString(data_format_str, &data_format); |
| |
| constexpr int num_spatial_dims = 2; |
| const int dims = |
| GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); |
| ShapeHandle input; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input)); |
| |
| int32 block_size; |
| TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size)); |
| |
| DimensionHandle batch_size = |
| c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); |
| DimensionHandle input_height = |
| c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); |
| DimensionHandle input_width = |
| c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); |
| DimensionHandle input_depth = |
| c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); |
| |
| DimensionHandle output_height; |
| DimensionHandle output_width; |
| DimensionHandle output_depth; |
| // Will return an error if input height or width are not evenly divisible. |
| TF_RETURN_IF_ERROR(c->Divide(input_height, block_size, |
| true /* evenly_divisible */, |
| &output_height)); |
| TF_RETURN_IF_ERROR(c->Divide(input_width, block_size, |
| true /* evenly_divisible */, &output_width)); |
| |
| TF_RETURN_IF_ERROR( |
| c->Multiply(input_depth, block_size * block_size, &output_depth)); |
| |
| ShapeHandle output_shape; |
| TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size, |
| {output_height, output_width}, |
| output_depth, &output_shape, c)); |
| |
| c->set_output(0, output_shape); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| REGISTER_OP("DepthToSpace") |
| .Input("input: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("block_size: int >= 2") |
| .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") |
| // TODO(pauldonnelly): Implement GPU kernels for NCHW and NCHW_VECT_C. |
| .SetShapeFn([](InferenceContext* c) { |
| string data_format_str; |
| TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); |
| TensorFormat data_format; |
| FormatFromString(data_format_str, &data_format); |
| |
| constexpr int num_spatial_dims = 2; |
| const int dims = |
| GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); |
| |
| ShapeHandle input; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input)); |
| |
| int32 block_size; |
| TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size)); |
| |
| DimensionHandle batch_size = |
| c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); |
| DimensionHandle input_height = |
| c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); |
| DimensionHandle input_width = |
| c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); |
| DimensionHandle input_depth = |
| c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); |
| |
| DimensionHandle output_height; |
| DimensionHandle output_width; |
| DimensionHandle output_depth; |
| TF_RETURN_IF_ERROR(c->Multiply(input_height, block_size, &output_height)); |
| TF_RETURN_IF_ERROR(c->Multiply(input_width, block_size, &output_width)); |
| |
| // Will return an error if input_depth is not evenly divisible. |
| TF_RETURN_IF_ERROR(c->Divide(input_depth, block_size * block_size, |
| true /* evenly_divisible */, &output_depth)); |
| |
| ShapeHandle output_shape; |
| TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size, |
| {output_height, output_width}, |
| output_depth, &output_shape, c)); |
| |
| c->set_output(0, output_shape); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| |
| REGISTER_OP("ExtractImagePatches") |
| .Input("images: T") |
| .Output("patches: T") |
| .Attr("ksizes: list(int) >= 4") |
| .Attr("strides: list(int) >= 4") |
| .Attr("rates: list(int) >= 4") |
| .Attr("T: realnumbertype") |
| .Attr(GetPaddingAttrString()) |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input_shape; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); |
| |
| std::vector<int32> ksizes; |
| TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes)); |
| if (ksizes.size() != 4) { |
| return errors::InvalidArgument( |
| "ExtractImagePatches requires the ksizes attribute to contain 4 " |
| "values, but got: ", |
| ksizes.size()); |
| } |
| |
| std::vector<int32> strides; |
| TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); |
| if (strides.size() != 4) { |
| return errors::InvalidArgument( |
| "ExtractImagePatches requires the stride attribute to contain 4 " |
| "values, but got: ", |
| strides.size()); |
| } |
| |
| std::vector<int32> rates; |
| TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates)); |
| if (rates.size() != 4) { |
| return errors::InvalidArgument( |
| "ExtractImagePatches requires the rates attribute to contain 4 " |
| "values, but got: ", |
| rates.size()); |
| } |
| |
| int32 ksize_rows = ksizes[1]; |
| int32 ksize_cols = ksizes[2]; |
| |
| int32 stride_rows = strides[1]; |
| int32 stride_cols = strides[2]; |
| |
| int32 rate_rows = rates[1]; |
| int32 rate_cols = rates[2]; |
| |
| int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); |
| int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); |
| |
| DimensionHandle batch_size_dim = c->Dim(input_shape, 0); |
| DimensionHandle in_rows_dim = c->Dim(input_shape, 1); |
| DimensionHandle in_cols_dim = c->Dim(input_shape, 2); |
| DimensionHandle output_depth_dim; |
| TF_RETURN_IF_ERROR(c->Multiply( |
| c->Dim(input_shape, 3), ksize_rows * ksize_cols, &output_depth_dim)); |
| |
| if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim)) { |
| ShapeHandle output_shape = |
| c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim, |
| InferenceContext::kUnknownDim, output_depth_dim}); |
| c->set_output(0, output_shape); |
| return Status::OK(); |
| } |
| auto in_rows = c->Value(in_rows_dim); |
| auto in_cols = c->Value(in_cols_dim); |
| |
| Padding padding; |
| TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); |
| |
| int64 output_rows, output_cols; |
| int64 padding_before, padding_after; |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( |
| in_rows, ksize_rows_eff, stride_rows, padding, &output_rows, |
| &padding_before, &padding_after)); |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( |
| in_cols, ksize_cols_eff, stride_cols, padding, &output_cols, |
| &padding_before, &padding_after)); |
| ShapeHandle output_shape = c->MakeShape( |
| {batch_size_dim, output_rows, output_cols, output_depth_dim}); |
| c->set_output(0, output_shape); |
| return Status::OK(); |
| }); |
| |
| // -------------------------------------------------------------------------- |
| |
| REGISTER_OP("Bitcast") |
| .Input("input: T") |
| .Output("output: type") |
| // All supported dtypes are listed here to include qint16, quint16, uint32, |
| // and uint64. |
| .Attr( |
| "T: {bfloat16, half, float, double, int64, int32, uint8, uint16, " |
| "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, " |
| "qint16, quint16, qint32}") |
| .Attr( |
| "type: {bfloat16, half, float, double, int64, int32, uint8, uint16, " |
| "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, " |
| "qint16, quint16, qint32}") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input = c->input(0); |
| if (!c->RankKnown(input)) { |
| // Input shape unknown. |
| return shape_inference::UnknownShape(c); |
| } |
| |
| // Find the size of the input and output data types. |
| DataType input_type; |
| DataType output_type; |
| TF_RETURN_IF_ERROR(c->GetAttr("T", &input_type)); |
| TF_RETURN_IF_ERROR(c->GetAttr("type", &output_type)); |
| const int input_type_size = DataTypeSize(input_type); |
| const int output_type_size = DataTypeSize(output_type); |
| |
| if (input_type_size == 0 || output_type_size == 0) { |
| return errors::InvalidArgument("Cannot bitcast types ", |
| DataTypeString(input_type), " to ", |
| DataTypeString(output_type), |
| " because " |
| "one of the type sizes is zero."); |
| } |
| |
| ShapeHandle new_shape; |
| if (input_type_size == output_type_size) { |
| // No change in size. |
| new_shape = input; |
| } else if (input_type_size < output_type_size) { |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 1, &new_shape)); |
| |
| int64 divisor_val = output_type_size / input_type_size; |
| DimensionHandle last_dim = c->Dim(new_shape, -1); |
| if (!c->ValueKnown(last_dim) || c->Value(last_dim) == divisor_val) { |
| TF_RETURN_IF_ERROR(c->Subshape(new_shape, 0, -1, &new_shape)); |
| } else { |
| return errors::InvalidArgument("Cannot bitcast due to shape. ", |
| c->Value(last_dim), " does not match ", |
| divisor_val); |
| } |
| } else { |
| // Input type size is larger than output type size. |
| int64 divisor_val = input_type_size / output_type_size; |
| ShapeHandle extension = c->Vector(divisor_val); |
| TF_RETURN_IF_ERROR(c->Concatenate(input, extension, &new_shape)); |
| } |
| |
| c->set_output(0, new_shape); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("OneHot") |
| .Input("indices: TI") |
| .Input("depth: int32") |
| .Input("on_value: T") |
| .Input("off_value: T") |
| .Attr("axis: int = -1") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("TI: {uint8, int32, int64} = DT_INT64") |
| .SetShapeFn([](InferenceContext* c) { |
| int32 axis; |
| TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis)); |
| if (axis < -1) return errors::InvalidArgument("axis must be >= -1"); |
| |
| DimensionHandle depth; |
| TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth)); |
| |
| ShapeHandle indices = c->input(0); |
| if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c); |
| |
| int32 new_rank = c->Rank(indices) + 1; |
| // We need to add new_rank to axis in the case the axis is -1 because |
| // C++ returns negative values from % if the dividend is negative. |
| int32 depth_index = (axis + new_rank) % new_rank; |
| // Out shape is indices[0:depth_index] + [depth] + indices[depth_index:]. |
| ShapeHandle front; |
| ShapeHandle back; |
| ShapeHandle out; |
| TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front)); |
| TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back)); |
| TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front)); |
| TF_RETURN_IF_ERROR(c->Concatenate(front, back, &out)); |
| c->set_output(0, out); |
| return Status::OK(); |
| }); |
| |
| // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. |
| REGISTER_OP("QuantizeAndDequantize") |
| .Input("input: T") |
| .Attr("signed_input: bool = true") |
| .Attr("num_bits: int = 8") |
| .Attr("range_given: bool = false") |
| .Attr("input_min: float = 0") |
| .Attr("input_max: float = 0") |
| .Output("output: T") |
| .Attr("T: {bfloat16, half, float, double}") |
| .SetShapeFn(shape_inference::UnchangedShape) |
| .Deprecated(22, "Replaced by QuantizeAndDequantizeV2"); |
| |
| // TODO(suharshs): Deprecate QuantizeAndDequantizeV2. |
| REGISTER_OP("QuantizeAndDequantizeV2") |
| .Input("input: T") |
| .Input("input_min: T") |
| .Input("input_max: T") |
| .Attr("signed_input: bool = true") |
| .Attr("num_bits: int = 8") |
| .Attr("range_given: bool = false") |
| .Output("output: T") |
| .Attr("T: {bfloat16, half, float, double}") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle unused; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
| c->set_output(0, c->input(0)); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("QuantizeAndDequantizeV3") |
| .Input("input: T") |
| .Input("input_min: T") |
| .Input("input_max: T") |
| .Input("num_bits: int32") |
| .Attr("signed_input: bool = true") |
| .Attr("range_given: bool = true") |
| .Output("output: T") |
| .Attr("T: {bfloat16, half, float, double}") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle unused; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
| c->set_output(0, c->input(0)); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("QuantizeV2") |
| .Input("input: float") |
| .Input("min_range: float") |
| .Input("max_range: float") |
| .Output("output: T") |
| .Output("output_min: float") |
| .Output("output_max: float") |
| .Attr("T: quantizedtype") |
| .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'") |
| .Attr( |
| "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = " |
| "'HALF_AWAY_FROM_ZERO'") |
| .SetShapeFn([](InferenceContext* c) { |
| TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
| ShapeHandle unused; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
| c->set_output(1, c->Scalar()); |
| c->set_output(2, c->Scalar()); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("Dequantize") |
| .Input("input: T") |
| .Input("min_range: float") |
| .Input("max_range: float") |
| .Output("output: float") |
| .Attr("T: quantizedtype") |
| .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'") |
| .SetShapeFn([](InferenceContext* c) { |
| TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
| ShapeHandle unused; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("QuantizedConcat") |
| .Input("concat_dim: int32") |
| .Input("values: N * T") |
| .Input("input_mins: N * float32") |
| .Input("input_maxes: N * float32") |
| .Output("output: T") |
| .Output("output_min: float") |
| .Output("output_max: float") |
| .Attr("N: int >= 2") |
| .Attr("T: type") |
| .SetShapeFn([](InferenceContext* c) { |
| const int n = (c->num_inputs() - 1) / 3; |
| TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c, n)); |
| ShapeHandle unused; |
| for (int i = n + 1; i < c->num_inputs(); ++i) { |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused)); |
| } |
| c->set_output(1, c->Scalar()); |
| c->set_output(2, c->Scalar()); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("QuantizedReshape") |
| .Input("tensor: T") |
| .Input("shape: Tshape") |
| .Input("input_min: float") |
| .Input("input_max: float") |
| .Output("output: T") |
| .Output("output_min: float") |
| .Output("output_max: float") |
| .Attr("T: type") |
| .Attr("Tshape: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| TF_RETURN_IF_ERROR(SetOutputShapeForReshape(c)); |
| ShapeHandle unused; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
| c->set_output(1, c->Scalar()); |
| c->set_output(2, c->Scalar()); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("QuantizedInstanceNorm") |
| .Input("x: T") |
| .Input("x_min: float") |
| .Input("x_max: float") |
| .Output("y: T") |
| .Output("y_min: float") |
| .Output("y_max: float") |
| .Attr("T: quantizedtype") |
| .Attr("output_range_given: bool = false") |
| .Attr("given_y_min: float = 0") |
| .Attr("given_y_max: float = 0") |
| .Attr("variance_epsilon: float = 1e-5") |
| .Attr("min_separation: float = 1e-3") |
| .SetShapeFn([](shape_inference::InferenceContext* c) { |
| shape_inference::ShapeHandle unused; |
| // x should be a rank 4 tensor. |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &unused)); |
| // Assert x_min and x_max are scalars (rank 0). |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
| // y has the same shape as x. |
| TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
| // y_min and y_max are scalars. |
| c->set_output(1, c->Scalar()); |
| c->set_output(2, c->Scalar()); |
| return Status::OK(); |
| }); |
| |
| namespace { |
| |
| Status ScatterNdShape(InferenceContext* c) { |
| ShapeHandle indices_shape; |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape)); |
| ShapeHandle updates_shape; |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape)); |
| ShapeHandle output_shape; |
| TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape)); |
| |
| if (c->Value(c->NumElements(output_shape)) == 0 && |
| (c->Value(c->NumElements(indices_shape)) > 0 || |
| c->Value(c->NumElements(updates_shape)) > 0)) { |
| return errors::InvalidArgument( |
| "Indices and updates specified for empty output shape"); |
| } |
| |
| if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { |
| const int64 outer_dims = c->Rank(indices_shape) - 1; |
| const DimensionHandle ixdim = c->Dim(indices_shape, -1); |
| |
| // We can only do more validation if the last dimension of indices |
| // is a known value. |
| if (c->ValueKnown(ixdim)) { |
| int64 ix = c->Value(ixdim); |
| ShapeHandle unused; |
| ShapeHandle prefix_indices; |
| TF_RETURN_IF_ERROR( |
| c->Subshape(indices_shape, 0, outer_dims, &prefix_indices)); |
| ShapeHandle prefix_updates; |
| TF_RETURN_IF_ERROR( |
| c->Subshape(updates_shape, 0, outer_dims, &prefix_updates)); |
| |
| Status s = c->Merge(prefix_indices, prefix_updates, &unused); |
| if (!s.ok()) { |
| return errors::InvalidArgument( |
| "The outer ", outer_dims, |
| " dimensions of indices.shape=", c->DebugString(indices_shape), |
| " must match the outer ", outer_dims, |
| " dimensions of updates.shape=", c->DebugString(updates_shape), |
| ": ", s.error_message()); |
| } |
| |
| ShapeHandle suffix_output; |
| TF_RETURN_IF_ERROR(c->Subshape(output_shape, ix, &suffix_output)); |
| ShapeHandle suffix_updates; |
| TF_RETURN_IF_ERROR( |
| c->Subshape(updates_shape, outer_dims, &suffix_updates)); |
| s = c->Merge(suffix_output, suffix_updates, &unused); |
| if (!s.ok()) { |
| return errors::InvalidArgument( |
| "The inner ", c->Rank(output_shape) - ix, |
| " dimensions of output.shape=", c->DebugString(output_shape), |
| " must match the inner ", c->Rank(updates_shape) - outer_dims, |
| " dimensions of updates.shape=", c->DebugString(updates_shape), |
| ": ", s.error_message()); |
| } |
| } |
| } |
| |
| c->set_output(0, output_shape); |
| return Status::OK(); |
| } |
| |
| } // namespace |
| |
| REGISTER_OP("UpperBound") |
| .Input("sorted_inputs: T") |
| .Input("values: T") |
| .Output("output: out_type") |
| .Attr("T: type") |
| .Attr("out_type: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle unused_shape; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape)); |
| c->set_output(0, c->input(1)); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("LowerBound") |
| .Input("sorted_inputs: T") |
| .Input("values: T") |
| .Output("output: out_type") |
| .Attr("T: type") |
| .Attr("out_type: {int32, int64} = DT_INT32") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle unused_shape; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape)); |
| c->set_output(0, c->input(1)); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("ScatterNd") |
| .Input("indices: Tindices") |
| .Input("updates: T") |
| .Input("shape: Tindices") |
| .Output("output: T") |
| .Attr("T: type") |
| .Attr("Tindices: {int32, int64}") |
| .SetShapeFn(ScatterNdShape); |
| |
| REGISTER_OP("ScatterNdNonAliasingAdd") |
| .Input("input: T") |
| .Input("indices: Tindices") |
| .Input("updates: T") |
| .Output("output: T") |
| .Attr("T: {numbertype, bool}") |
| .Attr("Tindices: {int32, int64}") |
| .SetShapeFn(shape_inference::ScatterNdUpdateShape); |
| |
| REGISTER_OP("FakeQuantWithMinMaxArgs") |
| .Attr("min: float = -6.0") |
| .Attr("max: float = 6.0") |
| .Attr("num_bits: int = 8") |
| .Attr("narrow_range: bool = false") |
| .Input("inputs: float") |
| .Output("outputs: float") |
| .SetShapeFn(shape_inference::UnchangedShape); |
| |
| REGISTER_OP("FakeQuantWithMinMaxArgsGradient") |
| .Attr("min: float = -6.0") |
| .Attr("max: float = 6.0") |
| .Attr("num_bits: int = 8") |
| .Attr("narrow_range: bool = false") |
| .Input("gradients: float") |
| .Input("inputs: float") |
| .Output("backprops: float") |
| .SetShapeFn(shape_inference::UnchangedShape); |
| |
| REGISTER_OP("FakeQuantWithMinMaxVars") |
| .Attr("num_bits: int = 8") |
| .Attr("narrow_range: bool = false") |
| .Input("inputs: float") |
| .Input("min: float") |
| .Input("max: float") |
| .Output("outputs: float") |
| .SetShapeFn([](InferenceContext* c) { |
| TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
| ShapeHandle unused; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("FakeQuantWithMinMaxVarsGradient") |
| .Attr("num_bits: int = 8") |
| .Attr("narrow_range: bool = false") |
| .Input("gradients: float") |
| .Input("inputs: float") |
| .Input("min: float") |
| .Input("max: float") |
| .Output("backprops_wrt_input: float") |
| .Output("backprop_wrt_min: float") |
| .Output("backprop_wrt_max: float") |
| .SetShapeFn([](InferenceContext* c) { |
| // gradients and inputs are same size. |
| ShapeHandle inputs; |
| TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs)); |
| |
| // min and max are scalars |
| ShapeHandle min_max; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max)); |
| TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max)); |
| |
| c->set_output(0, inputs); |
| c->set_output(1, min_max); |
| c->set_output(2, min_max); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel") |
| .Attr("num_bits: int = 8") |
| .Attr("narrow_range: bool = false") |
| .Input("inputs: float") |
| .Input("min: float") |
| .Input("max: float") |
| .Output("outputs: float") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle input, min, max; |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &min)); |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max)); |
| |
| DimensionHandle unused; |
| TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(min, 0), &unused)); |
| TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(max, 0), &unused)); |
| TF_RETURN_IF_ERROR(c->Merge(c->Dim(min, 0), c->Dim(max, 0), &unused)); |
| |
| c->set_output(0, input); |
| return Status::OK(); |
| }); |
| |
| REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient") |
| .Attr("num_bits: int = 8") |
| .Attr("narrow_range: bool = false") |
| .Input("gradients: float") |
| .Input("inputs: float") |
| .Input("min: float") |
| .Input("max: float") |
| .Output("backprops_wrt_input: float") |
| .Output("backprop_wrt_min: float") |
| .Output("backprop_wrt_max: float") |
| .SetShapeFn([](InferenceContext* c) { |
| ShapeHandle inputs; |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs)); |
| TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs)); |
| TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs)); |
| |
| ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1)); |
| |
| ShapeHandle min_max; |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max)); |
| TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max)); |
| TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max)); |
| |
| c->set_output(0, inputs); |
| c->set_output(1, min_max); |
| c->set_output(2, min_max); |
| return Status::OK(); |
| }); |
| |
#ifdef INTEL_MKL
// _MklConcat: inputs are laid out as concat_dim, N values, mkl_concat_dim and
// N mkl_values, i.e. 2N + 2 tensors in total.
REGISTER_OP("_MklConcat")
    .Input("concat_dim: int32")
    .Input("values: N * T")
    .Input("mkl_concat_dim: uint8")
    .Input("mkl_values: N * uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      // ConcatShape expects the number of tensors being concatenated (N), as
      // in QuantizedConcat's shape function. Remove concat_dim and
      // mkl_concat_dim from the count, then halve it to drop the N mkl_values
      // entries, so that only the real `values` inputs take part in shape
      // inference.
      const int num_values = (c->num_inputs() - 2) / 2;
      // NOTE(review): only output 0 is populated here; the shape of
      // `mkl_output` is left unset -- confirm this is intended.
      return shape_inference::ConcatShape(c, num_values);
    })
    .Doc(R"doc(
MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
#endif
| |
| // Deprecated op registrations: |
| |
| // The following can be deleted after 10mar2017. |
| REGISTER_OP("BatchMatrixDiag") |
| .Input("diagonal: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .Deprecated(14, "Use MatrixDiag") |
| .SetShapeFn(shape_inference::UnknownShape); |
| REGISTER_OP("BatchMatrixSetDiag") |
| .Input("input: T") |
| .Input("diagonal: T") |
| .Output("output: T") |
| .Attr("T: type") |
| .Deprecated(14, "Use MatrixSetDiag") |
| .SetShapeFn(shape_inference::UnknownShape); |
| REGISTER_OP("BatchMatrixDiagPart") |
| .Input("input: T") |
| .Output("diagonal: T") |
| .Attr("T: type") |
| .Deprecated(14, "Use MatrixDiagPart") |
| .SetShapeFn(shape_inference::UnknownShape); |
| REGISTER_OP("BatchMatrixBandPart") |
| .Input("input: T") |
| .Input("num_lower: int64") |
| .Input("num_upper: int64") |
| .Output("band: T") |
| .Attr("T: type") |
| .Deprecated(14, "Use MatrixBandPart") |
| .SetShapeFn(shape_inference::UnknownShape); |
| |
| } // namespace tensorflow |