blob: bfd29a7b1e7b9054a1c204da164b3a4cab5a289a [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
REGISTER_OP("_TPUCompileMlir")
.Attr("num_computations: int >= 0")
.Attr("mlir_module: string=\"\"")
.Attr("metadata: string")
.Attr("NumDynamicShapes: int >= 0")
// Do not try to optimize me away. We would like the compilation-op to be
// invoked for every step, and not be constant-folded away, in case the
// program is evicted from the compilation cache.
.SetIsStateful()
.Input("dynamic_shapes: NumDynamicShapes * int64")
.Output("compilation_status: string")
.Output("program: num_computations * string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
int num_computations;
TF_RETURN_IF_ERROR(
GetNodeAttr(c->attrs(), "num_computations", &num_computations));
// Compilation status.
c->set_output(0, c->Scalar());
// Programs.
for (int i = 0; i < num_computations; ++i) {
c->set_output(i + 1, c->Vector(2));
}
return Status::OK();
})
.Doc(
R"(
Compiles a computations for execution on one or more TPU devices.
For the internal use of the distributed TPU compiler.
'mlir_module' is a serialized MLIR module with a `main` function that contains
target computation.
'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not
known statically at TPUReplication rewrite time.
'metadata' is a serialized TPUCompileMetadataProto describing the shapes and
types of the inputs to the computation, as well as a mapping onto the TPU pod
topology.
'program' output is a string key that is passed to the TPUExecute op and used to
look up the program in the compilation cache.
)");
REGISTER_OP("_TPUCompileMlirPlaceholderProgramKey")
.SetIsStateful()
.Output("program: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Vector(2));
return Status::OK();
})
.SetIsStateful()
.Doc(
R"(
Placeholder program key (compilation cache key) of a _TPUCompileMlir `program`.
This op can be used when certain rewrite passes materialize ops that require a
program key but the _TPUCompileMlir op has not been added yet. Subsequent
rewrite passes must replace this op with a _TPUCompileMlir op `program` output.
)");
REGISTER_OP("TPUCompile")
.Attr("num_computations: int >= 0")
.Attr("function: func")
.Attr("metadata: string")
.Attr("NumDynamicShapes: int >= 0")
.Attr("Tguaranteed_constants: list(type) >= 0")
// Do not try to optimize me away. We would like the compilation-op to be
// invoked for every step, and not be constant-folded away, in case the
// program is evicted from the compilation cache.
.SetIsStateful()
.Input("dynamic_shapes: NumDynamicShapes * int64")
.Input("guaranteed_constants: Tguaranteed_constants")
.Output("compilation_status: string")
.Output("program: num_computations * string")
.Output("may_modify_variables: num_computations * bool")
.SetShapeFn([](shape_inference::InferenceContext* c) {
int num_computations;
TF_RETURN_IF_ERROR(
GetNodeAttr(c->attrs(), "num_computations", &num_computations));
// Compilation status.
c->set_output(0, c->Scalar());
// Programs.
for (int i = 0; i < num_computations; ++i) {
c->set_output(i + 1, c->Vector(2));
}
// May modify variables.
for (int i = 0; i < num_computations; ++i) {
c->set_output(num_computations + i + 1, c->Scalar());
}
return Status::OK();
});
REGISTER_OP("TPUCompileSucceededAssert")
.Input("compilation_status: string")
// Do not optimize me away. Read the comment on TPUCompileOp for more
// details.
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs);
} // namespace tensorflow