Merge pull request #22371 from aaroey:fix_zero_size_allocation
PiperOrigin-RevId: 213998222
diff --git a/RELEASE.md b/RELEASE.md
index 2f26623..20e1d92 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -94,7 +94,7 @@
## Breaking Changes
-* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) and [Install TensorFlow from Sources](https://www.tensorflow.org/install/install_sources#optional_install_tensorflow_for_gpu_prerequisites).
+* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. Using TensorFlow with multiple GPUs and NCCL requires upgrading to [NCCL 2.2](https://developer.nvidia.com/nccl). See the updated install guides: [TensorFlow GPU support](https://www.tensorflow.org/install/gpu) and [Build TensorFlow from source](https://www.tensorflow.org/install/source).
* Starting from TensorFlow 1.11, Windows builds will use Bazel. Therefore, we will drop official support for cmake.
## Bug Fixes and Other Changes
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 49990b62..41b5b8f 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -29,15 +29,8 @@
namespace tensorflow {
namespace eager {
-// Information about a tensor.
-struct TapeTensor {
- int64 id; // Expected to be unique in the lifetime of this process.
- DataType dtype;
- TensorShape shape;
-};
-
// Represents an entry in the tape.
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
struct OpTapeEntry {
string op_type;
std::vector<TapeTensor> output_tensor_info;
@@ -57,8 +50,8 @@
using TensorTape = gtl::FlatMap<int64, int64>;
// Map from operation-id to tape entry.
-template <typename BackwardFunction>
-using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
+template <typename BackwardFunction, typename TapeTensor>
+using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;
// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
@@ -79,7 +72,7 @@
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class VSpace {
public:
virtual ~VSpace() {}
@@ -93,10 +86,10 @@
gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
// Returns a tensor of the right shape and dtype filled with zeros.
- virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
+ virtual Gradient* Zeros(const TapeTensor& tensor) const = 0;
// Returns a Tensor which is filled with ones and like the input.
- virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0;
+ virtual Gradient* Ones(const TapeTensor& tensor) const = 0;
// Calls the passed-in backward function.
virtual Status CallBackwardFunction(
@@ -114,7 +107,7 @@
// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not thread-safe.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class GradientTape {
public:
// If `persistent` is true, GradientTape will not eagerly delete backward
@@ -134,7 +127,7 @@
void Watch(int64 tensor_id);
void RecordOperation(
- const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+ const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
@@ -146,17 +139,18 @@
// once) and produces the gradient of the target tensors with respect to the
// source tensors. The output gradients are used if not empty and not
// null. The result is populated with one tensor per target element.
- Status ComputeGradient(const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<int64> source_tensor_id,
- gtl::ArraySlice<Gradient*> output_gradients,
- std::vector<Gradient*>* result);
+ Status ComputeGradient(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<int64> source_tensor_id,
+ gtl::ArraySlice<Gradient*> output_gradients,
+ std::vector<Gradient*>* result);
bool IsPersistent() const { return persistent_; }
private:
TensorTape tensor_tape_;
- OpTape<BackwardFunction> op_tape_;
+ OpTape<BackwardFunction, TapeTensor> op_tape_;
int64 next_op_id_{0};
// Map from tensor id to number of remaining usages (i.e. how many entries in
@@ -186,8 +180,8 @@
}
}
-template <typename Gradient, typename BackwardFunction>
-bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes) {
CHECK_EQ(tensor_ids.size(), dtypes.size());
@@ -201,14 +195,15 @@
return false;
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
+ int64 tensor_id) {
tensor_tape_.emplace(tensor_id, -1);
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::RecordOperation(
- const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
+ const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
@@ -229,16 +224,18 @@
for (const TapeTensor& o : output_tensors) {
// Note: the tensor can have already been watched and hence be in the tape,
// so we cannot check that we're inserting it here.
- tensor_tape_[o.id] = op_id;
- tensor_usage_[o.id] = 1;
+ tensor_tape_[o.GetID()] = op_id;
+ tensor_usage_[o.GetID()] = 1;
tensors.push_back(o);
}
- op_tape_[op_id] = OpTapeEntry<BackwardFunction>{
- op_type, tensors, ids, backward_function, backward_function_deleter};
+ op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
+ op_type, std::move(tensors), ids, backward_function,
+ backward_function_deleter};
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
+ int64 tensor_id) {
auto it = tensor_usage_.find(tensor_id);
if (it == tensor_usage_.end()) {
return;
@@ -261,7 +258,7 @@
auto op_it = op_tape_.find(op_id);
CHECK(op_it != op_tape_.end());
for (const auto& output : op_it->second.output_tensor_info) {
- if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
+ if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
// Found a usage for an output, so cannot delete the op.
return;
}
@@ -304,9 +301,9 @@
namespace {
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
struct BackpropInitialState {
- OpTape<BackwardFunction> op_tape;
+ OpTape<BackwardFunction, TapeTensor> op_tape;
// Map from tensor ID to how many references still exist for this tensor in
// the tape.
@@ -322,17 +319,17 @@
// If `persistent_tape` is false, op_tape is cleared and backwards functions
// not needed for gradient computation are deleted. Backwards functions that
// are needed, are copied and returned in BackpropInitialState.
-template <typename BackwardFunction>
-BackpropInitialState<BackwardFunction> PrepareBackprop(
+template <typename BackwardFunction, typename TapeTensor>
+BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
- OpTape<BackwardFunction>* op_tape, const gtl::FlatSet<int64>& sources_set,
- bool persistent_tape) {
+ OpTape<BackwardFunction, TapeTensor>* op_tape,
+ const gtl::FlatSet<int64>& sources_set, bool persistent_tape) {
std::vector<int64> tensor_stack;
tensor_stack.reserve(target.size());
for (auto t : target) {
tensor_stack.push_back(t);
}
- BackpropInitialState<BackwardFunction> result;
+ BackpropInitialState<BackwardFunction, TapeTensor> result;
while (!tensor_stack.empty()) {
int64 tensor_id = tensor_stack.back();
tensor_stack.pop_back();
@@ -383,9 +380,9 @@
return result;
}
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
std::vector<int64> InitialStack(
- const OpTape<BackwardFunction>& op_tape,
+ const OpTape<BackwardFunction, TapeTensor>& op_tape,
const gtl::FlatMap<int64, int64>& op_missing_tensor) {
std::vector<int64> result;
for (auto& op_entry : op_tape) {
@@ -396,13 +393,13 @@
return result;
}
-template <typename Gradient, typename BackwardFunction>
-Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<Gradient*> output_gradients,
- const TensorTape& tensor_tape,
- const OpTape<BackwardFunction>& op_tape,
- gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status InitialGradients(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
+ const OpTape<BackwardFunction, TapeTensor>& op_tape,
+ gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
if (output_gradients.empty() || output_gradients[i] == nullptr) {
@@ -416,11 +413,10 @@
}
bool found = false;
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
- if (op_it->second.output_tensor_info[j].id == id) {
+ if (op_it->second.output_tensor_info[j].GetID() == id) {
found = true;
(*result)[id].push_back(
- vspace.Ones(op_it->second.output_tensor_info[j].shape,
- op_it->second.output_tensor_info[j].dtype));
+ vspace.Ones(op_it->second.output_tensor_info[j]));
break;
}
}
@@ -469,16 +465,16 @@
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
-template <typename Gradient, typename BackwardFunction>
-Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
- const VSpace<Gradient, BackwardFunction>& vspace,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
gtl::ArraySlice<int64> target_tensor_ids,
gtl::ArraySlice<int64> source_tensor_ids,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
- BackpropInitialState<BackwardFunction> state = PrepareBackprop(
+ BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
std::vector<int64> op_stack =
InitialStack(state.op_tape, state.op_missing_tensor);
@@ -522,7 +518,7 @@
out_gradients.reserve(trace.output_tensor_info.size());
bool any_gradient_nonzero = false;
for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
- const int64 id = trace.output_tensor_info[i].id;
+ const int64 id = trace.output_tensor_info[i].GetID();
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
@@ -531,9 +527,7 @@
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
- out_gradients.push_back(
- vspace.Zeros(trace.output_tensor_info[i].shape,
- trace.output_tensor_info[i].dtype));
+ out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i]));
}
} else {
any_gradient_nonzero = true;
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 8486b58..247236b 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -110,7 +110,7 @@
session->extend_before_run = false;
}
-std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
Node* node = &output.oper->node;
CppShapeInferenceResult::HandleData handle_data;
handle_data.set_is_set(true);
@@ -135,9 +135,8 @@
return result;
}
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
- const void* proto, size_t proto_len,
- TF_Status* status) {
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+ size_t proto_len, TF_Status* status) {
tensorflow::CppShapeInferenceResult::HandleData handle_data;
if (!handle_data.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index 4bcb5bd..5cce840 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -54,16 +54,17 @@
void ExtendSession(TF_Session* session, TF_Status* status);
// Returns the serialized CppShapeInferenceResult::HandleData proto for
-// `output` if its a resource tensor, or otherwise returns the empty string.
-std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
+// `output` if it's a resource or variant tensor, or otherwise returns the empty
+// string.
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
// Sets `output` based on `proto`, which should be a serialized
-// CppShapeInferenceResult::HandleData proto.
+// CppShapeInferenceResult::HandleData proto. `output` should be a resource
+// or variant tensor.
// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string
// because I couldn't get SWIG to work otherwise.
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
- const void* proto, size_t proto_len,
- TF_Status* status);
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+ size_t proto_len, TF_Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 7a0932d..10fa33a 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -25,6 +25,7 @@
":test_graph_tfmatmul_test",
":test_graph_tfmatmulandadd_test",
":test_graph_tfsplits_test",
+ ":test_graph_tftop_k_test",
":tfcompile_test",
],
)
@@ -42,6 +43,7 @@
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
"//tensorflow/python:session",
"//tensorflow/python:training",
@@ -66,6 +68,7 @@
"test_graph_tfmatmul.pb",
"test_graph_tfmatmulandadd.pb",
"test_graph_tfsplits.pb",
+ "test_graph_tftop_k.pb",
],
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
# GPUs which might be present. This is important because builds may run
@@ -208,6 +211,17 @@
],
)
+tf_library(
+ name = "test_graph_tftop_k",
+ testonly = 1,
+ config = "test_graph_tftop_k.config.pbtxt",
+ cpp_class = "TopKComp",
+ graph = "test_graph_tftop_k.pb",
+ tags = [
+ "manual",
+ ],
+)
+
tf_cc_test(
name = "tfcompile_test",
srcs = ["tfcompile_test.cc"],
@@ -226,6 +240,7 @@
":test_graph_tfmatmulandadd",
":test_graph_tfmatmulandadd_with_profiling",
":test_graph_tfsplits",
+ ":test_graph_tftop_k",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 9ec7df1..de135d7 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -31,6 +31,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.training import saver as saver_lib
@@ -142,6 +143,12 @@
array_ops.identity(y, name='result')
+def tftop_k(_):
+ x = array_ops.placeholder(dtypes.int32, shape=[5], name='x')
+ output = nn_ops.top_k(x, 2, name='values')
+ array_ops.identity(output[1], name='indices')
+
+
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@@ -163,6 +170,7 @@
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
+ write_graph(tftop_k, FLAGS.out_dir)
if __name__ == '__main__':
diff --git a/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
new file mode 100644
index 0000000..6b4ac2d
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
@@ -0,0 +1,13 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+ id { node_name: "x" }
+ shape {
+ dim { size: 5 }
+ }
+}
+fetch {
+ id { node_name: "values" }
+}
+fetch {
+ id { node_name: "indices" }
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 7ac90fb..f10852c 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -29,6 +29,7 @@
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -448,6 +449,30 @@
EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
}
+TEST(TFCompileTest, TopK) {
+ Eigen::ThreadPool tp(1);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+ TopKComp fn;
+
+ fn.set_thread_pool(&device);
+ // x = [4, 1, 4, 4, 3]
+ fn.arg0(0) = 4;
+ fn.arg0(1) = 1;
+ fn.arg0(2) = 4;
+ fn.arg0(3) = 4;
+ fn.arg0(4) = 3;
+
+ EXPECT_TRUE(fn.Run());
+ EXPECT_EQ(fn.error_msg(), "");
+ const int32 expected_values[] = {4, 4};
+ const int32 expected_indices[] = {0, 2};
+ EXPECT_EQ(expected_values[0], fn.result0(0));
+ EXPECT_EQ(expected_values[1], fn.result0(1));
+ EXPECT_EQ(expected_indices[0], fn.result1(0));
+ EXPECT_EQ(expected_indices[1], fn.result1(1));
+}
+
TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the
// two args are different.
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 792b7fe..859c84b 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -273,6 +273,7 @@
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 1001c57..4e18472 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -26,6 +26,7 @@
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
@@ -50,7 +51,7 @@
visibility = ["//visibility:public"],
deps = [
":jit_compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin",
],
@@ -62,7 +63,7 @@
visibility = ["//visibility:public"],
deps = if_cuda([
":jit_compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin",
]),
@@ -76,7 +77,7 @@
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/jit/legacy_flags:xla_device_flags",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -94,7 +95,7 @@
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
@@ -111,7 +112,7 @@
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep
@@ -280,7 +281,7 @@
deps = [
":common",
":compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -341,7 +342,7 @@
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -359,7 +360,7 @@
cc_library(
name = "compilation_passes",
srcs = [
- "build_xla_launch_ops_pass.cc",
+ "build_xla_ops_pass.cc",
"deadness_analysis.cc",
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
@@ -369,7 +370,7 @@
"partially_decluster_pass.cc",
],
hdrs = [
- "build_xla_launch_ops_pass.h",
+ "build_xla_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
"encapsulate_xla_computations_pass.h",
@@ -459,7 +460,7 @@
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -493,7 +494,7 @@
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
@@ -524,7 +525,7 @@
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -628,6 +629,15 @@
],
)
+tf_custom_op_py_library(
+ name = "xla_ops_py",
+ kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],
+ visibility = [
+ ":friends",
+ ],
+ deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"],
+)
+
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
deleted file mode 100644
index b17ff58..0000000
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
+++ /dev/null
@@ -1,142 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
-#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/framework/graph_def_util.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/public/version.h"
-
-namespace tensorflow {
-
-static Status BuildLaunchNode(
- const string& nodename, const string& function_name,
- const AttrValueMap& function_attr, const string& device_name,
- const DataTypeVector& constant_dtypes, int num_resources,
- const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes,
- Graph* graph, Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("XlaLaunch");
- def.set_device(device_name);
- AddNodeAttr("Tconstants", constant_dtypes, &def);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Nresources", num_resources, &def);
- AddNodeAttr("Tresults", result_dtypes, &def);
- NameAttrList function;
- function.set_name(function_name);
- *function.mutable_attr() = function_attr;
- AddNodeAttr("function", function, &def);
-
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
-}
-
-static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
- VLOG(2) << "Replacing " << node->name() << " with XlaLaunch";
-
- int num_constant_args, num_resource_args;
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args));
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args));
-
- if (num_constant_args < 0 || num_resource_args < 0 ||
- num_constant_args + num_resource_args > node->num_inputs()) {
- return errors::InvalidArgument(
- "Invalid number of constant/resource arguments to XLA kernel.");
- }
- const int num_nonconst_args =
- node->num_inputs() - num_constant_args - num_resource_args;
-
- DataTypeVector const_dtypes(node->input_types().begin(),
- node->input_types().begin() + num_constant_args);
- DataTypeVector arg_dtypes(
- node->input_types().begin() + num_constant_args,
- node->input_types().begin() + num_constant_args + num_nonconst_args);
-
- // Build a XlaLaunch operator to execute the function body.
- Node* launch_node;
- TF_RETURN_IF_ERROR(BuildLaunchNode(
- graph->NewName(node->name()), node->type_string(), node->def().attr(),
- node->requested_device(), const_dtypes, num_resource_args, arg_dtypes,
- node->output_types(), graph, &launch_node));
- launch_node->set_assigned_device_name(node->assigned_device_name());
-
- // Copy incoming edges to the launch node.
- for (const Edge* edge : node->in_edges()) {
- if (edge->IsControlEdge()) {
- graph->AddControlEdge(edge->src(), launch_node);
- } else {
- graph->AddEdge(edge->src(), edge->src_output(), launch_node,
- edge->dst_input());
- }
- }
-
- // Copy outgoing edges to the launch node.
- std::vector<const Edge*> out_edges(node->out_edges().begin(),
- node->out_edges().end());
- for (const Edge* edge : out_edges) {
- Node* dst = edge->dst();
- int src_output = edge->src_output();
- int dst_input = edge->dst_input();
- graph->RemoveEdge(edge);
-
- if (edge->IsControlEdge()) {
- graph->AddControlEdge(launch_node, dst);
- } else {
- graph->AddEdge(launch_node, src_output, dst, dst_input);
- }
- }
- graph->RemoveNode(node);
-
- return Status::OK();
-}
-
-Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
- Graph* graph = options.graph->get();
-
- for (Node* n : graph->op_nodes()) {
- // In all cases, only try to compile computational nodes.
- if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
- continue;
- }
-
- // Only compile nodes that are marked for compilation by the
- // compilation-marking pass (via 'attr_name').
- if (IsXlaCompiledKernel(*n)) {
- TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n));
- }
- }
-
- if (VLOG_IS_ON(1)) {
- dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph,
- options.flib_def);
- }
- return Status::OK();
-}
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
new file mode 100644
index 0000000..a6086f3
--- /dev/null
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -0,0 +1,187 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+// Adds an _XlaCompile node to `graph`, named after `nodename` (made unique via
+// graph->NewName) and placed on `device_name`. The node's "function" attribute
+// names `function_name` instantiated with `function_attr`; the constant and
+// argument dtypes and the resource count are recorded as Tconstants, Targs and
+// Nresources. On return `*node` holds whatever Graph::AddNode produced and the
+// returned status is the one AddNode reported.
+static Status BuildXlaCompileNode(
+    const string& nodename, const string& function_name,
+    const AttrValueMap& function_attr, const string& device_name,
+    const DataTypeVector& constant_dtypes, int num_resources,
+    const DataTypeVector& arg_dtypes, Graph* graph, Node** node) {
+  NodeDef compile_def;
+  compile_def.set_name(graph->NewName(nodename));
+  compile_def.set_op("_XlaCompile");
+  compile_def.set_device(device_name);
+  AddNodeAttr("Tconstants", constant_dtypes, &compile_def);
+  AddNodeAttr("Targs", arg_dtypes, &compile_def);
+  AddNodeAttr("Nresources", num_resources, &compile_def);
+
+  NameAttrList compiled_function;
+  compiled_function.set_name(function_name);
+  *compiled_function.mutable_attr() = function_attr;
+  AddNodeAttr("function", compiled_function, &compile_def);
+
+  Status add_status;
+  *node = graph->AddNode(compile_def, &add_status);
+  return add_status;
+}
+
+// Adds an _XlaRun node to `graph`, named after `nodename` (made unique via
+// graph->NewName) and placed on `device_name`. The constant, argument and
+// result dtypes are recorded as the Tconstants, Targs and Tresults attributes.
+// On return `*node` holds whatever Graph::AddNode produced and the returned
+// status is the one AddNode reported.
+static Status BuildXlaRunNode(const string& nodename, const string& device_name,
+                              const DataTypeVector& constant_dtypes,
+                              const DataTypeVector& arg_dtypes,
+                              const DataTypeVector& result_dtypes, Graph* graph,
+                              Node** node) {
+  NodeDef run_def;
+  run_def.set_name(graph->NewName(nodename));
+  run_def.set_op("_XlaRun");
+  run_def.set_device(device_name);
+
+  AddNodeAttr("Tconstants", constant_dtypes, &run_def);
+  AddNodeAttr("Targs", arg_dtypes, &run_def);
+  AddNodeAttr("Tresults", result_dtypes, &run_def);
+
+  Status add_status;
+  *node = graph->AddNode(run_def, &add_status);
+  return add_status;
+}
+
+// Reads the constant/resource argument counts recorded on an XLA-compiled
+// function call node and splits the node's input dtypes accordingly.
+//
+// The code assumes the inputs of `node` are ordered as
+// [constants..., non-constant args..., resources...]. On success:
+//   * *num_constant_args / *num_resource_args hold the attribute values,
+//   * the dtypes of the leading constant inputs are appended to *const_dtypes,
+//   * the dtypes of the following non-resource inputs are appended to
+//     *arg_dtypes (resource input dtypes are not copied anywhere).
+// Returns an error if either attribute is missing, or InvalidArgument if the
+// counts are negative or together exceed the node's number of inputs.
+static Status GetXlaAttrs(Node* node, int* num_constant_args,
+                          int* num_resource_args, DataTypeVector* const_dtypes,
+                          DataTypeVector* arg_dtypes) {
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, num_constant_args));
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, num_resource_args));
+
+  const int num_inputs = node->num_inputs();
+  // Both counts come from node attributes, so do not add them together: the
+  // sum can overflow `int` (undefined behavior) and slip past the range check,
+  // leading to out-of-bounds iterator arithmetic below. Comparing against
+  // `num_inputs - *num_resource_args` is equivalent for in-range values and
+  // cannot overflow once both counts are known to be non-negative.
+  if (*num_constant_args < 0 || *num_resource_args < 0 ||
+      *num_resource_args > num_inputs ||
+      *num_constant_args > num_inputs - *num_resource_args) {
+    return errors::InvalidArgument(
+        "Invalid number of constant/resource arguments to XLA kernel.");
+  }
+
+  const int num_nonconst_args =
+      num_inputs - *num_constant_args - *num_resource_args;
+
+  const DataTypeVector& input_types = node->input_types();
+  std::copy(input_types.begin(), input_types.begin() + *num_constant_args,
+            std::back_inserter(*const_dtypes));
+  std::copy(input_types.begin() + *num_constant_args,
+            input_types.begin() + *num_constant_args + num_nonconst_args,
+            std::back_inserter(*arg_dtypes));
+  return Status::OK();
+}
+
+// Gives `new_node` a copy of every incoming edge of `old_node`. Data edges
+// keep their source output and destination input slots; control edges are
+// re-added as control edges. The edges of `old_node` are left in place.
+static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node) {
+  for (const Edge* in_edge : old_node->in_edges()) {
+    Node* const src = in_edge->src();
+    if (!in_edge->IsControlEdge()) {
+      g->AddEdge(src, in_edge->src_output(), new_node, in_edge->dst_input());
+      continue;
+    }
+    g->AddControlEdge(src, new_node);
+  }
+}
+
+// Re-targets every outgoing edge of `old_node` so that it originates from
+// `new_node`. The source output slot, destination node and destination input
+// slot of each edge are preserved, and control edges stay control edges.
+static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
+  // Snapshot the edge list first: RemoveEdge mutates old_node->out_edges().
+  std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
+                                     old_node->out_edges().end());
+  for (const Edge* edge : out_edges) {
+    // Read everything needed from `edge` before removing it. After
+    // RemoveEdge the Edge object is no longer part of the graph (the graph
+    // may free or recycle it), so it must not be dereferenced again.
+    Node* dst = edge->dst();
+    const int src_output = edge->src_output();
+    const int dst_input = edge->dst_input();
+    const bool is_control_edge = edge->IsControlEdge();
+    g->RemoveEdge(edge);
+
+    if (is_control_edge) {
+      g->AddControlEdge(new_node, dst);
+    } else {
+      g->AddEdge(new_node, src_output, dst, dst_input);
+    }
+  }
+}
+
+// Replaces the XLA-compiled function call node `n` with an _XlaCompile /
+// _XlaRun pair:
+//   * both new nodes get copies of every incoming edge of `n`;
+//   * output 0 of _XlaCompile (the compilation key) feeds _XlaRun at input
+//     index n->num_inputs(), i.e. right after the inputs copied from `n`;
+//   * all outgoing edges of `n` are moved onto _XlaRun;
+// after which `n` is removed from the graph. Both new nodes inherit the
+// requested and assigned devices of `n`.
+static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) {
+  int num_constant_args;
+  int num_resource_args;
+  DataTypeVector const_dtypes;
+  DataTypeVector arg_dtypes;
+  TF_RETURN_IF_ERROR(GetXlaAttrs(n, &num_constant_args, &num_resource_args,
+                                 &const_dtypes, &arg_dtypes));
+
+  Node* compile_node = nullptr;
+  TF_RETURN_IF_ERROR(BuildXlaCompileNode(
+      n->name(), n->type_string(), n->def().attr(), n->requested_device(),
+      const_dtypes, num_resource_args, arg_dtypes, g, &compile_node));
+
+  // _XlaRun receives the resource inputs as trailing DT_RESOURCE arguments.
+  DataTypeVector run_arg_dtypes(arg_dtypes);
+  for (int remaining = num_resource_args; remaining > 0; --remaining) {
+    run_arg_dtypes.push_back(DT_RESOURCE);
+  }
+
+  Node* run_node = nullptr;
+  TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(),
+                                     const_dtypes, run_arg_dtypes,
+                                     n->output_types(), g, &run_node));
+
+  compile_node->set_assigned_device_name(n->assigned_device_name());
+  run_node->set_assigned_device_name(n->assigned_device_name());
+
+  for (Node* replacement : {compile_node, run_node}) {
+    CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/replacement);
+  }
+
+  // The compilation_key output.
+  g->AddEdge(compile_node, 0, run_node, n->num_inputs());
+
+  MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+  g->RemoveNode(n);
+  return Status::OK();
+}
+
+// Rewrites every node marked as an XLA-compiled kernel into an _XlaCompile /
+// _XlaRun pair (see ReplaceNodeWithXlaCompileAndRun). Send, Recv and control
+// flow nodes are never rewritten. The resulting graph is dumped when VLOG(1) is
+// enabled. Returns the first rewrite error encountered, if any.
+Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
+  Graph* graph = options.graph->get();
+
+  // Collect the candidates before rewriting anything. Each rewrite adds and
+  // removes graph nodes, so mutating the graph while walking op_nodes() would
+  // make the traversal depend on how the node iterator copes with concurrent
+  // modification. A rewrite only removes the node being replaced, so the other
+  // collected pointers stay valid.
+  std::vector<Node*> xla_compiled_kernels;
+  for (Node* n : graph->op_nodes()) {
+    // In all cases, only try to compile computational nodes.
+    if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
+      continue;
+    }
+
+    // Only rewrite nodes that the compilation-marking pass tagged as
+    // XLA-compiled kernels (as reported by IsXlaCompiledKernel).
+    if (IsXlaCompiledKernel(*n)) {
+      xla_compiled_kernels.push_back(n);
+    }
+  }
+
+  for (Node* n : xla_compiled_kernels) {
+    TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndRun(graph, n));
+  }
+
+  if (VLOG_IS_ON(1)) {
+    dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def);
+  }
+  return Status::OK();
+}
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h
similarity index 71%
rename from tensorflow/compiler/jit/build_xla_launch_ops_pass.h
rename to tensorflow/compiler/jit/build_xla_ops_pass.h
index 1dfea93..1dd38fa 100644
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.h
@@ -13,19 +13,21 @@
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
-#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-class BuildXlaLaunchOpsPass : public GraphOptimizationPass {
+// Adds _XlaCompile and _XlaRun operations to the TF graph that compile and
+// execute (using XLA) TF function calls marked with "_XlaCompiledKernel".
+class BuildXlaOpsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 56b034a..6f1ff85 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -16,7 +16,7 @@
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index 3770eea..085c0e5 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -13,7 +13,7 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -55,6 +55,6 @@
// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
- BuildXlaLaunchOpsPass);
+ BuildXlaOpsPass);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 253a5d2..0839f1c 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -7,9 +7,9 @@
)
cc_library(
- name = "xla_launch_op",
- srcs = ["xla_launch_op.cc"],
- hdrs = ["xla_launch_op.h"],
+ name = "xla_ops",
+ srcs = ["xla_ops.cc"],
+ hdrs = ["xla_ops.h"],
deps = [
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:xla_compilation_cache",
@@ -26,6 +26,7 @@
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/memory",
],
alwayslink = 1,
)
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
deleted file mode 100644
index b6f2f63..0000000
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ /dev/null
@@ -1,276 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
-
-#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/xla_launch_util.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/variable_ops.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-
-namespace tensorflow {
-
-XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
- const std::vector<int>& constants,
- const std::vector<int>& resources,
- const NameAttrList& function)
- : OpKernel(ctx),
- constants_(constants),
- resources_(resources),
- device_type_(ctx->device_type()),
- function_(function) {
- if (device_type_ == DeviceType(DEVICE_CPU)) {
- platform_id_ = se::host::kHostPlatformId;
- } else if (device_type_ == DeviceType(DEVICE_GPU)) {
- platform_id_ = ctx->device()
- ->tensorflow_gpu_device_info()
- ->stream->parent()
- ->platform()
- ->id();
- } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) {
- use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams();
- platform_id_ = xla_device_metadata_->platform()->id();
- }
-}
-
-Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache) {
- if (xla_device_metadata_) {
- *cache = new XlaCompilationCache(xla_device_metadata_->client(),
- xla_device_metadata_->jit_device_type());
- return Status::OK();
- }
-
- auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_);
- if (!platform.ok()) {
- return platform.status();
- }
- xla::LocalClientOptions client_options;
- client_options.set_platform(platform.ValueOrDie());
- client_options.set_intra_op_parallelism_threads(
- ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
- auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
- if (!client.ok()) {
- return client.status();
- }
- const XlaOpRegistry::DeviceRegistration* registration;
- if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(),
- ®istration)) {
- return errors::InvalidArgument("No JIT device registered for ",
- device_type_.type());
- }
- *cache = new XlaCompilationCache(
- client.ValueOrDie(), DeviceType(registration->compilation_device_name));
- return Status::OK();
-}
-
-void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
- VLOG(1) << "XlaLocalLaunchOpBase::Compute "
- << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
- // We store information about the JIT-compiled XLA computation
- // in the ResourceMgr.
- ResourceMgr* rm = ctx->resource_manager();
- OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
-
- se::Stream* stream =
- ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
-
- XlaCompilationCache* cache;
- OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>(
- rm->default_container(), "xla_cache", &cache,
- [this, ctx](XlaCompilationCache** cache) {
- return BuildCompilationCache(ctx, cache);
- }));
- // Hold the reference to the JIT during evaluation. (We could probably
- // free it sooner because the ResourceMgr will retain a reference, but
- // this is more obviously correct.)
- core::ScopedUnref cache_ref(cache);
-
- std::map<int, OptionalTensor> variables =
- SnapshotResourceVariables(ctx, resources_);
-
- xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
-
- XlaAllocator local_xla_allocator(client->backend().platform(),
- ctx->device()->GetAllocator({}));
- xla::DeviceMemoryAllocator* xla_allocator;
- // If we are on an XlaDevice, use the underlying XLA platform's allocator
- // directly. We could use the StreamExecutor's allocator which may
- // theoretically be more correct, but XLA returns a nice OOM message in a
- // Status and StreamExecutor does not.
- //
- // Importantly we can't use ctx->device()->GetAllocator() as the allocator
- // (which local_xla_allocator above uses) as on an XlaDevice, this is a
- // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
- // real allocator to allocate real buffers.
- if (xla_device_metadata_) {
- xla_allocator = client->backend().memory_allocator();
- } else {
- xla_allocator = &local_xla_allocator;
- }
-
- XlaCompiler::Options options;
- options.client = client;
- if (ctx->op_device_context() != nullptr) {
- options.device_ordinal =
- ctx->op_device_context()->stream()->parent()->device_ordinal();
- }
- options.device_type = cache->device_type();
- options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
- options.graph_def_version = ctx->function_library()->graph_def_version();
- options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
- options.device_allocator = xla_allocator;
- if (xla_device_metadata_) {
- options.shape_representation_fn =
- xla_device_metadata_->shape_representation_fn();
- }
-
- const XlaCompiler::CompilationResult* kernel;
- xla::LocalExecutable* executable;
-
- std::map<int, Tensor> constant_args;
- for (int i : constants_) {
- constant_args.insert({i, ctx->input(i)});
- }
- XlaCompiler::CompileOptions compile_options;
- compile_options.is_entry_computation = true;
- // If we resolve constants we never emit them on the device, meaning that if
- // they are needed by a following computation the host has to transfer
- // them. Not resolving constants is expected to be faster than resolving
- // constants.
- compile_options.resolve_compile_time_constants = true;
- // Optimization: where possible, have the computation return a naked array
- // rather than a one-element tuple.
- compile_options.always_return_tuple = false;
-
- OP_REQUIRES_OK(
- ctx, cache->Compile(options, function_, constant_args, variables, ctx,
- &kernel, &executable, compile_options));
-
- VLOG(1) << "Executing XLA Computation...";
-
- XlaComputationLaunchContext launch_context(
- client, xla_allocator,
- /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr,
- use_multiple_streams_);
- launch_context.PopulateInputs(ctx, kernel, variables);
-
- // Execute the computation.
- VLOG(2) << "Executing computation.";
- xla::ExecutableRunOptions run_options;
- run_options.set_stream(stream);
- run_options.set_allocator(xla_allocator);
- run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
- run_options.set_rng_seed(GetXLARandomSeed());
- Env* env = Env::Default();
- auto start_time = env->NowMicros();
-
- auto run_result = executable->Run(launch_context.arguments(), run_options);
- OP_REQUIRES(ctx, run_result.ok(), run_result.status());
-
- auto elapsed = env->NowMicros() - start_time;
- VLOG(2) << "Elapsed time: " << elapsed << "us";
-
- OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
- ctx, kernel, run_result.ConsumeValueOrDie()));
- VLOG(1) << "Done";
-}
-
-namespace {
-
-// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
-// in error case, it returns RET instead of void.
-#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
- do { \
- ::tensorflow::Status _s(__VA_ARGS__); \
- if (!TF_PREDICT_TRUE(_s.ok())) { \
- (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
- return RET; \
- } \
- } while (0)
-
-// Helper static functions to construct parameters for
-// XlaLocalLaunchBase constructor from OpKernelConstruction.
-std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
- DataTypeVector constant_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Tconstants", &constant_types));
- std::vector<int> constants(constant_types.size());
- std::iota(constants.begin(), constants.end(), 0);
- return constants;
-}
-
-std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
- DataTypeVector constant_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Tconstants", &constant_types));
-
- DataTypeVector arg_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Targs", &arg_types));
-
- int num_resources;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Nresources", &num_resources));
-
- std::vector<int> resources(num_resources);
- std::iota(resources.begin(), resources.end(),
- constant_types.size() + arg_types.size());
- return resources;
-}
-
-NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
- const NameAttrList* func;
- OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
- return *func;
-}
-
-#undef OP_REQUIRES_OK_RETURN
-} // namespace
-
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
- : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
- FunctionAttr(ctx)) {}
-
-XlaLocalLaunchOp::~XlaLocalLaunchOp() {
- VLOG(1) << "XlaLocalLaunchOp destroyed";
-}
-
-REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
-
-REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
- .Device(DEVICE_GPU)
- .HostMemory("constants")
- .HostMemory("resources"),
- XlaLocalLaunchOp);
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
deleted file mode 100644
index e0f10e9..0000000
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
-#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
-
-#include "tensorflow/compiler/jit/xla_compilation_cache.h"
-#include "tensorflow/compiler/jit/xla_device.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-
-namespace tensorflow {
-
-// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
-// The only difference is that it does not require arguments to follow
-// the "constants, then regular args, then resources" order.
-// It takes vectors of constant and resource arguments explicitly.
-// It does not have corresponding OpDef because it is never present
-// in the GraphDef.
-// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
-// this kernel when asked to create a kernel for an XLA-compiled function.
-class XlaLocalLaunchBase : public OpKernel {
- public:
- XlaLocalLaunchBase(OpKernelConstruction* ctx,
- const std::vector<int>& constants,
- const std::vector<int>& resources,
- const NameAttrList& function);
- XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
- XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
- ~XlaLocalLaunchBase() override = default;
-
- void Compute(OpKernelContext* ctx) override;
-
- protected:
- // Builds a XlaCompilationCache class suitable for the current device.
- Status BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache);
-
- // Indexes of compile-time constant inputs
- std::vector<int> constants_;
- // Indexes of resource inputs
- std::vector<int> resources_;
-
- DeviceType device_type_;
- NameAttrList function_;
- se::Platform::Id platform_id_ = nullptr;
- bool use_multiple_streams_ = false;
- const XlaDevice::Metadata* xla_device_metadata_ = nullptr;
-};
-
-// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
-// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
-// responsible for handling interactions with the TensorFlow executor.
-// Once all inputs are present, and their shapes are known, the op can
-// use a 'XlaCompilationCache' to compile and execute code which is specific
-// to the shapes of input Tensors.
-// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
-// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
-// memory.
-class XlaLocalLaunchOp : public XlaLocalLaunchBase {
- public:
- explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
- ~XlaLocalLaunchOp() override;
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
-};
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
new file mode 100644
index 0000000..c483841
--- /dev/null
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -0,0 +1,488 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/variable_ops.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Fills `*result` with the XLA platform information for the device that `ctx`
+// is constructing a kernel on:
+//   * CPU: the host platform;
+//   * GPU: the platform of the device's stream executor;
+//   * XLA device: the device's own platform and its XLA backend allocator.
+// When no backend allocator was picked (every case except an XLA device), an
+// XlaAllocator wrapping the TF device allocator is created for the platform.
+// Returns an error if that platform cannot be looked up.
+Status PlatformInfoFromContext(OpKernelConstruction* ctx,
+                               XlaPlatformInfo* result) {
+  DeviceType device_type = ctx->device_type();
+  se::Platform::Id platform_id = nullptr;
+  const XlaDevice::Metadata* xla_device_metadata = nullptr;
+  xla::DeviceMemoryAllocator* device_allocator = nullptr;
+
+  if (device_type == DeviceType(DEVICE_CPU)) {
+    platform_id = se::host::kHostPlatformId;
+  } else if (device_type == DeviceType(DEVICE_GPU)) {
+    auto* gpu_executor =
+        ctx->device()->tensorflow_gpu_device_info()->stream->parent();
+    platform_id = gpu_executor->platform()->id();
+  } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
+    // On an XlaDevice, allocate through the underlying XLA platform's
+    // allocator directly. The StreamExecutor's allocator might be more correct
+    // in theory, but XLA reports a readable OOM error in a Status and
+    // StreamExecutor does not.
+    //
+    // ctx->device()->GetAllocator() (which the XlaAllocator created below
+    // wraps) cannot be used here: on an XlaDevice it is a dummy allocator that
+    // returns XlaTensor objects, while the XlaCompiler needs a real allocator
+    // that hands out real buffers.
+    platform_id = xla_device_metadata->platform()->id();
+    device_allocator =
+        xla_device_metadata->client()->backend().memory_allocator();
+  }
+
+  std::unique_ptr<XlaAllocator> xla_allocator;
+  if (device_allocator == nullptr) {
+    TF_ASSIGN_OR_RETURN(se::Platform* const platform,
+                        se::MultiPlatformManager::PlatformWithId(platform_id));
+    xla_allocator = absl::make_unique<XlaAllocator>(
+        platform, ctx->device()->GetAllocator({}));
+  }
+
+  *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
+                            std::move(xla_allocator), device_allocator);
+  return Status::OK();
+}
+
+// Bundles what is needed to run a compiled version of a TensorFlow function:
+// the XLA client, the executable, the compilation result it came from, and the
+// resource-variable snapshots that were given to the compiler.
+//
+// Keeping the snapshots here is deliberate. Execution must start from the
+// exact values the compiler observed; snapshotting the variables again at
+// execution time could observe values whose shapes differ from the ones the
+// executable was compiled for.
+//
+// The closure is move-only. It holds the client, executable and compilation
+// result as raw pointers and never deletes them.
+class XlaExecutableClosure {
+ public:
+  explicit XlaExecutableClosure(
+      xla::LocalClient* client, xla::LocalExecutable* executable,
+      const XlaCompiler::CompilationResult* compilation_result,
+      std::map<int, OptionalTensor> resource_var_snapshots)
+      : client_(client),
+        executable_(executable),
+        compilation_result_(compilation_result),
+        resource_var_snapshots_(std::move(resource_var_snapshots)) {}
+
+  XlaExecutableClosure(XlaExecutableClosure&&) = default;
+  XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;
+  XlaExecutableClosure(const XlaExecutableClosure&) = delete;
+  XlaExecutableClosure& operator=(const XlaExecutableClosure&) = delete;
+
+  // Non-owning accessors for the pointers passed to the constructor.
+  xla::LocalClient* client() const { return client_; }
+  xla::LocalExecutable* executable() const { return executable_; }
+  const XlaCompiler::CompilationResult* compilation_result() const {
+    return compilation_result_;
+  }
+
+  // Resource-variable snapshots, keyed by input index as produced by the
+  // caller (presumably SnapshotResourceVariables; confirm at the call site).
+  const std::map<int, OptionalTensor>& resource_var_snapshots() const {
+    return resource_var_snapshots_;
+  }
+
+ private:
+  xla::LocalClient* client_;
+  xla::LocalExecutable* executable_;
+  const XlaCompiler::CompilationResult* compilation_result_;
+  std::map<int, OptionalTensor> resource_var_snapshots_;
+};
+
+// Registry mapping unique string keys to XlaExecutableClosure instances.
+// Produce() stores a closure under a fresh key; Consume() removes and returns
+// the closure for a key, so each key can be redeemed at most once. All
+// accesses are serialized by `mutex_`.
+class XlaExecutableClosureStore {
+ public:
+  XlaExecutableClosureStore() : key_counter_(0) {}
+
+  using KeyT = string;
+
+  // Stores `result` and returns the key that identifies it. Keys come from a
+  // monotonically increasing counter, so they never repeat.
+  KeyT Produce(XlaExecutableClosure result) {
+    mutex_lock l(mutex_);
+    KeyT key = absl::StrCat(key_counter_++);
+    bool insert_successful = closures_.emplace(key, std::move(result)).second;
+    DCHECK(insert_successful);
+    (void)insert_successful;
+    return key;
+  }
+
+  // Removes and returns the closure stored under `key`. Crashes in every build
+  // mode if `key` is unknown (never produced, or already consumed).
+  XlaExecutableClosure Consume(const KeyT& key) {
+    mutex_lock l(mutex_);
+    auto it = closures_.find(key);
+    // Checked unconditionally: with only a DCHECK, optimized builds would go on
+    // to dereference end(), which is undefined behavior.
+    CHECK(it != closures_.end())
+        << "No XlaExecutableClosure registered for key " << key;
+    XlaExecutableClosure value = std::move(it->second);
+    closures_.erase(it);
+    return value;
+  }
+
+  // Returns the process-wide store. The instance is never deleted.
+  static XlaExecutableClosureStore* Global() {
+    static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
+    return instance;
+  }
+
+ private:
+  mutex mutex_;
+  int64 key_counter_ GUARDED_BY(mutex_);
+  gtl::FlatMap<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
+};
+
+} // namespace
+
+// Stores the indices of the constant and resource inputs and the function to
+// compile, then resolves the platform information for this kernel's device.
+// If the platform cannot be determined, the failure is recorded on `ctx` and
+// kernel construction fails.
+XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
+                                       const std::vector<int>& constants,
+                                       const std::vector<int>& resources,
+                                       const NameAttrList& function)
+    : OpKernel(ctx),
+      constants_(constants),
+      resources_(resources),
+      function_(function) {
+  const Status platform_status = PlatformInfoFromContext(ctx, &platform_info_);
+  OP_REQUIRES_OK(ctx, platform_status);
+}
+
+// Creates the XlaCompilationCache for the device described by `platform_info`
+// and stores it in `*cache` (this file uses it as the creation callback of
+// ResourceMgr::LookupOrCreate).
+//
+// * On an XLA device, the cache reuses the device's XLA client and JIT device
+//   type from its metadata.
+// * Otherwise a local XLA client is obtained for platform_info.platform_id(),
+//   using the device's CPU worker thread count for intra-op parallelism, and
+//   the cache targets the compilation device registered for
+//   platform_info.device_type().
+// Returns an error if the platform or the client cannot be obtained, or
+// InvalidArgument if no JIT device is registered for the device type.
+static Status BuildCompilationCache(OpKernelContext* ctx,
+                                    const XlaPlatformInfo& platform_info,
+                                    XlaCompilationCache** cache) {
+  if (platform_info.xla_device_metadata()) {
+    *cache = new XlaCompilationCache(
+        platform_info.xla_device_metadata()->client(),
+        platform_info.xla_device_metadata()->jit_device_type());
+    return Status::OK();
+  }
+
+  auto platform =
+      se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
+  if (!platform.ok()) {
+    return platform.status();
+  }
+  xla::LocalClientOptions client_options;
+  client_options.set_platform(platform.ValueOrDie());
+  client_options.set_intra_op_parallelism_threads(
+      ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
+  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
+  if (!client.ok()) {
+    return client.status();
+  }
+  const XlaOpRegistry::DeviceRegistration* registration = nullptr;
+  if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
+                                           &registration)) {
+    return errors::InvalidArgument("No JIT device registered for ",
+                                   platform_info.device_type().type());
+  }
+  *cache = new XlaCompilationCache(
+      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
+  return Status::OK();
+}
+
+// Looks up (or lazily creates) the "xla_cache" XlaCompilationCache in the
+// kernel's ResourceMgr and asks it to compile `function` for the current
+// inputs.
+//
+// Outputs:
+//   *client     - the XLA LocalClient held by the cache.
+//   *variables  - snapshot of the resource variables at indices `resources`.
+//   *kernel, *executable - as produced by XlaCompilationCache::Compile.
+// Returns Internal if the context has no ResourceMgr, otherwise any error
+// from cache creation or compilation.
+static Status CompileToLocalExecutable(
+    OpKernelContext* ctx, const NameAttrList& function,
+    const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
+    absl::Span<const int> constants, xla::LocalClient** client,
+    std::map<int, OptionalTensor>* variables,
+    const XlaCompiler::CompilationResult** kernel,
+    xla::LocalExecutable** executable) {
+  // Information about JIT-compiled XLA computations lives in the ResourceMgr.
+  ResourceMgr* rm = ctx->resource_manager();
+  if (rm == nullptr) {
+    return errors::Internal("No resource manager.");
+  }
+
+  XlaCompilationCache* cache;
+  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
+      rm->default_container(), "xla_cache", &cache,
+      [&](XlaCompilationCache** new_cache) {
+        return BuildCompilationCache(ctx, platform_info, new_cache);
+      }));
+  // Keep our own reference to the cache for the whole call. The ResourceMgr
+  // also holds one, but an explicit reference is easier to reason about.
+  core::ScopedUnref cache_ref(cache);
+
+  *variables = SnapshotResourceVariables(ctx, resources);
+  *client = static_cast<xla::LocalClient*>(cache->client());
+
+  // Compiler options derived from the cache, the device and the platform.
+  XlaCompiler::Options options;
+  options.client = *client;
+  if (ctx->op_device_context() != nullptr) {
+    options.device_ordinal =
+        ctx->op_device_context()->stream()->parent()->device_ordinal();
+  }
+  options.device_type = cache->device_type();
+  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+  options.graph_def_version = ctx->function_library()->graph_def_version();
+  options.allow_cpu_custom_calls =
+      (platform_info.platform_id() == se::host::kHostPlatformId);
+  options.device_allocator = platform_info.allocator();
+  if (const XlaDevice::Metadata* metadata =
+          platform_info.xla_device_metadata()) {
+    options.shape_representation_fn = metadata->shape_representation_fn();
+  }
+
+  // Compile-time constant inputs, keyed by input index.
+  std::map<int, Tensor> constant_args;
+  for (int index : constants) {
+    constant_args.insert({index, ctx->input(index)});
+  }
+
+  XlaCompiler::CompileOptions compile_options;
+  compile_options.is_entry_computation = true;
+  // Trade-off: resolved constants are never emitted on the device, so if a
+  // following computation needs them the host has to transfer them. Leaving
+  // them unresolved is expected to be faster, but this path resolves them.
+  compile_options.resolve_compile_time_constants = true;
+  // Optimization: where possible, have the computation return a naked array
+  // rather than a one-element tuple.
+  compile_options.always_return_tuple = false;
+
+  return cache->Compile(options, function, constant_args, *variables, ctx,
+                        kernel, executable, compile_options);
+}
+
+// Compiles `function_` for the current inputs (reusing a cached executable
+// when the compilation cache has one), runs it, and publishes the results as
+// this kernel's outputs.
+void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
+  // Log the real class name so the message can be grepped back to this code.
+  VLOG(1) << "XlaLocalLaunchBase::Compute "
+          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
+
+  xla::LocalClient* client;
+  const XlaCompiler::CompilationResult* kernel;
+  xla::LocalExecutable* executable;
+  std::map<int, OptionalTensor> variables;
+
+  OP_REQUIRES_OK(
+      ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
+                                    constants_, &client, &variables, &kernel,
+                                    &executable));
+
+  se::Stream* stream =
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+
+  VLOG(1) << "Executing XLA Computation...";
+
+  XlaComputationLaunchContext launch_context(
+      client, platform_info_.allocator(),
+      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
+      /*use_multiple_streams=*/platform_info_.UseMultipleStreams());
+  launch_context.PopulateInputs(ctx, kernel, variables);
+
+  // Execute the computation.
+  VLOG(2) << "Executing computation.";
+  xla::ExecutableRunOptions run_options;
+  run_options.set_stream(stream);
+  run_options.set_allocator(platform_info_.allocator());
+  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+  run_options.set_rng_seed(GetXLARandomSeed());
+  Env* env = Env::Default();
+  auto start_time = env->NowMicros();
+
+  auto run_result = executable->Run(launch_context.arguments(), run_options);
+  OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+  auto elapsed = env->NowMicros() - start_time;
+  VLOG(2) << "Elapsed time: " << elapsed << "us";
+
+  OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
+                          ctx, kernel, run_result.ConsumeValueOrDie()));
+  VLOG(1) << "Done";
+}
+
+namespace {
+
+// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
+// in the error case it returns RET instead of void.
+#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
+  do {                                                      \
+    ::tensorflow::Status _s(__VA_ARGS__);                   \
+    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
+      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+      return RET;                                           \
+    }                                                       \
+  } while (0)
+
+// Helper functions that derive the XlaLocalLaunchBase constructor arguments
+// from an OpKernelConstruction. The XLA launch/compile ops lay out their inputs
+// as constants, then regular args, then resources. If an attribute lookup
+// fails, the error is recorded on `ctx` and an empty/default value is
+// returned.
+
+// Returns the indices of the compile-time constant inputs: 0 .. N-1, where N
+// is the length of the "Tconstants" attribute.
+std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
+  DataTypeVector constant_types;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Tconstants", &constant_types));
+  std::vector<int> constants(constant_types.size());
+  std::iota(constants.begin(), constants.end(), 0);
+  return constants;
+}
+
+// Returns the indices of the "Nresources" resource inputs, which follow all
+// constant and regular ("Targs") inputs.
+std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
+  DataTypeVector constant_types;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Tconstants", &constant_types));
+
+  DataTypeVector arg_types;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Targs", &arg_types));
+
+  int num_resources = 0;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Nresources", &num_resources));
+
+  // Seed std::iota with an int. Passing the size_t sum directly makes iota
+  // count in size_t and narrow implicitly on every store into vector<int>.
+  const int first_resource_index =
+      static_cast<int>(constant_types.size() + arg_types.size());
+  std::vector<int> resources(num_resources);
+  std::iota(resources.begin(), resources.end(), first_resource_index);
+  return resources;
+}
+
+// Returns the "function" attribute, i.e. the function to compile.
+NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
+  const NameAttrList* func = nullptr;
+  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
+  return *func;
+}
+
+#undef OP_REQUIRES_OK_RETURN
+}  // namespace
+
+// Derives all XlaLocalLaunchBase configuration (constant and resource input
+// indices, and the function to compile) from the node's attributes.
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+    : XlaLocalLaunchBase(ctx, /*constants=*/ConstantsVector(ctx),
+                         /*resources=*/ResourcesVector(ctx),
+                         /*function=*/FunctionAttr(ctx)) {}
+
+// Only emits a VLOG trace; the base class owns all state.
+XlaLocalLaunchOp::~XlaLocalLaunchOp() {
+  VLOG(1) << "XlaLocalLaunchOp destroyed";
+}
+
+// Reads the constant/resource input indices and the function to compile from
+// the node's attributes, then fills platform_info_. A failure in
+// PlatformInfoFromContext is reported on `ctx`.
+XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
+    : OpKernel(ctx),
+      constants_(ConstantsVector(ctx)),
+      resources_(ResourcesVector(ctx)),
+      function_(FunctionAttr(ctx)) {
+  const Status platform_status =
+      PlatformInfoFromContext(ctx, &platform_info_);
+  OP_REQUIRES_OK(ctx, platform_status);
+}
+
+// Compiles the cluster (or hits the compilation cache), registers the result
+// in XlaExecutableClosureStore and emits two scalar outputs: the string key
+// that the matching _XlaRun consumes, and a bool compilation_successful flag
+// (currently always true).
+void XlaCompileOp::Compute(OpKernelContext* ctx) {
+  xla::LocalClient* client;
+  const XlaCompiler::CompilationResult* kernel;
+  xla::LocalExecutable* executable;
+  std::map<int, OptionalTensor> variables;
+
+  OP_REQUIRES_OK(
+      ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
+                                    constants_, &client, &variables, &kernel,
+                                    &executable));
+
+  // A new closure is registered on every execution, even on a compilation
+  // cache hit, because each run needs fresh resource-variable snapshots.
+  const XlaExecutableClosureStore::KeyT key =
+      XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
+          client, executable, kernel, std::move(variables)));
+
+  // Both outputs are written by the CPU, so allocate them from host memory
+  // that is also GPU-compatible.
+  AllocatorAttributes host_alloc_attrs;
+  host_alloc_attrs.set_gpu_compatible(true);
+  host_alloc_attrs.set_on_host(true);
+  Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);
+
+  Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
+  compilation_key.scalar<string>()() = key;
+
+  Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
+  compilation_successful.scalar<bool>()() = true;
+
+  ctx->set_output(0, compilation_key);
+  ctx->set_output(1, compilation_successful);
+}
+
+// _XlaRun only needs platform information at construction time. Everything
+// else is fetched at run time through the compilation key.
+XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+  const Status platform_status =
+      PlatformInfoFromContext(ctx, &platform_info_);
+  OP_REQUIRES_OK(ctx, platform_status);
+}
+
+// Runs the executable that the matching _XlaCompile registered under the key
+// passed as this op's last input, then publishes the XLA outputs.
+void XlaRunOp::Compute(OpKernelContext* ctx) {
+  // The compilation key is always the final input.
+  const Tensor& key_tensor = ctx->input(ctx->num_inputs() - 1);
+  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0);
+
+  // Take ownership of the closure; the store no longer holds it afterwards.
+  XlaExecutableClosure closure =
+      XlaExecutableClosureStore::Global()->Consume(key);
+
+  XlaComputationLaunchContext launch_context(
+      closure.client(), platform_info_.allocator(),
+      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
+      /*use_multiple_streams=*/platform_info_.UseMultipleStreams());
+  launch_context.PopulateInputs(ctx, closure.compilation_result(),
+                                closure.resource_var_snapshots());
+
+  xla::ExecutableRunOptions run_options;
+  run_options.set_stream(ctx->op_device_context() != nullptr
+                             ? ctx->op_device_context()->stream()
+                             : nullptr);
+  run_options.set_allocator(platform_info_.allocator());
+  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+  run_options.set_rng_seed(GetXLARandomSeed());
+
+  Env* const env = Env::Default();
+  const auto start_micros = env->NowMicros();
+
+  auto run_result =
+      closure.executable()->Run(launch_context.arguments(), run_options);
+  OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+  const auto elapsed_micros = env->NowMicros() - start_micros;
+  VLOG(2) << "Elapsed time in computation: " << elapsed_micros << "us";
+
+  OP_REQUIRES_OK(
+      ctx, launch_context.PopulateOutputs(ctx, closure.compilation_result(),
+                                          run_result.ConsumeValueOrDie()));
+}
+
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("constants")
+                            .HostMemory("resources"),
+                        XlaLocalLaunchOp);
+
+// XlaCompileOp::Compute allocates both outputs (`key` and
+// `compilation_successful`) with a host allocator. Arguments that are not
+// listed in HostMemory() are assumed to live in device memory on GPU, so both
+// outputs must be pinned to host memory here.
+REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
+REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("constants")
+                            .HostMemory("key")
+                            .HostMemory("compilation_successful")
+                            .HostMemory("resources"),
+                        XlaCompileOp);
+
+REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
+
+// XlaRunOp::Compute reads the compilation key on the host.
+REGISTER_KERNEL_BUILDER(Name("_XlaRun")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("constants")
+                            .HostMemory("key"),
+                        XlaRunOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h
new file mode 100644
index 0000000..489d26e
--- /dev/null
+++ b/tensorflow/compiler/jit/kernels/xla_ops.h
@@ -0,0 +1,168 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
+#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+// Holds some information about the platform on which an
+// XlaLaunch/_XlaCompile/_XlaRun op must run on.
+class XlaPlatformInfo {
+ public:
+ XlaPlatformInfo() : device_type_("") {}
+ explicit XlaPlatformInfo(const DeviceType device_type,
+ se::Platform::Id platform_id,
+ const XlaDevice::Metadata* xla_device_metadata,
+ std::unique_ptr<XlaAllocator> xla_allocator,
+ xla::DeviceMemoryAllocator* device_allocator)
+ : device_type_(device_type),
+ platform_id_(platform_id),
+ xla_device_metadata_(xla_device_metadata),
+ xla_allocator_(std::move(xla_allocator)),
+ device_allocator_(device_allocator) {
+ CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr));
+ }
+
+ XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
+
+ bool UseMultipleStreams() const {
+ return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
+ }
+
+ xla::DeviceMemoryAllocator* allocator() const {
+ return device_allocator_ ? device_allocator_ : xla_allocator_.get();
+ }
+ DeviceType device_type() const { return device_type_; }
+
+ // This is equal to xla_device_metadata()->platform()->id() if
+ // xla_device_metadata() is not nullptr.
+ se::Platform::Id platform_id() const { return platform_id_; }
+
+ // This may be null if the op this XlaPlatformInfo is for was not placed on an
+ // XLA device.
+ const XlaDevice::Metadata* xla_device_metadata() const {
+ return xla_device_metadata_;
+ }
+ bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
+
+ private:
+ DeviceType device_type_;
+ se::Platform::Id platform_id_;
+
+ // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
+ // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
+ // XlaLaunch/_XlaCompile/_XlaRun OpKernel.
+ const XlaDevice::Metadata* xla_device_metadata_;
+
+ // If the op associated with this XlaPlatformInfo is placed on an XLA device
+ // then device_allocator_ is the xla::Backend's memory allocator and
+ // xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
+ // then device_allocator_ is null and xla_allocator_ points to an appropriate
+ // XlaAllocator instance.
+ std::unique_ptr<XlaAllocator> xla_allocator_;
+ xla::DeviceMemoryAllocator* device_allocator_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
+};
+
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes vectors of constant and resource arguments explicitly.
+// It does not have a corresponding OpDef because it is never present
+// in the GraphDef.
+// Currently, it is used by the eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+//
+// Compute() compiles `function_` (through the XlaCompilationCache stored in
+// the kernel's ResourceMgr) and executes the result within the same Compute()
+// call, unlike the _XlaCompile/_XlaRun pair, which splits those two steps
+// across two kernels.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+ // `constants` and `resources` are input indices; `function` names the
+ // function to compile. Platform information is resolved here as well.
+ XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function);
+ XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+ XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+ ~XlaLocalLaunchBase() override = default;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ protected:
+ // Indexes of compile-time constant inputs
+ std::vector<int> constants_;
+ // Indexes of resource inputs
+ std::vector<int> resources_;
+
+ // The function to compile and run.
+ NameAttrList function_;
+ // Device/platform/allocator information resolved at construction.
+ XlaPlatformInfo platform_info_;
+};
+
+// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
+// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
+// responsible for handling interactions with the TensorFlow executor.
+// Once all inputs are present, and their shapes are known, the op can
+// use an 'XlaCompilationCache' to compile and execute code which is specific
+// to the shapes of input Tensors.
+// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
+// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
+// memory.
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
+ public:
+ // Derives the constant/resource input indices from the Tconstants, Targs
+ // and Nresources attributes, and the function from the "function" attribute.
+ explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
+ ~XlaLocalLaunchOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
+};
+
+// Kernel for _XlaCompile: compiles the cluster function (or reuses a cached
+// compilation), stores the executable together with fresh resource-variable
+// snapshots in a process-wide closure store, and outputs a scalar string key
+// for that entry plus a scalar bool `compilation_successful` (currently always
+// true). The key is consumed by the matching _XlaRun.
+class XlaCompileOp : public OpKernel {
+ public:
+ explicit XlaCompileOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ // Indexes of compile-time constant inputs
+ std::vector<int> constants_;
+ // Indexes of resource inputs
+ std::vector<int> resources_;
+
+ // The function (cluster) to compile.
+ NameAttrList function_;
+
+ XlaPlatformInfo platform_info_;
+};
+
+// Kernel for _XlaRun: takes the key produced by _XlaCompile as its last
+// input, removes the matching closure from the closure store and executes it
+// on this op's inputs. Because the closure is removed, each key can be run
+// only once.
+class XlaRunOp : public OpKernel {
+ public:
+ explicit XlaRunOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ XlaPlatformInfo platform_info_;
+};
+
+} // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 1eaedbf..133d982 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -982,6 +982,11 @@
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
+ if (flags->tf_xla_clustering_debug) {
+ dump_graph::DumpGraphToFile("before_mark_for_compilation", **options.graph,
+ options.flib_def);
+ }
+
// Mark clusters for compilation that:
// * are placed on a device that requires compilation (an XlaDevice),
// * are explicitly marked for compilation (_XlaCompile=true), or
diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD
index 13804c6..f722245 100644
--- a/tensorflow/compiler/jit/ops/BUILD
+++ b/tensorflow/compiler/jit/ops/BUILD
@@ -4,9 +4,17 @@
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+
cc_library(
name = "xla_ops",
srcs = ["xla_ops.cc"],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)
+
+tf_gen_op_wrapper_py(
+ name = "xla_ops_wrapper_py",
+ out = "xla_ops.py",
+ deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
+)
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index 1a29c3c..6b4cdaa 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -51,4 +51,47 @@
"Operator that connects the output of an XLA computation to other "
"consumer graph nodes.");
+// First half of a split XlaLaunch. Inputs are laid out as constants, then
+// regular args, then resources; the kernel derives input indices from the
+// Tconstants/Targs/Nresources attrs under that assumption.
+REGISTER_OP("_XlaCompile")
+ .Input("constants: Tconstants")
+ .Attr("Tconstants: list(type) >= 0")
+ .Input("args: Targs")
+ .Attr("Targs: list(type) >= 0")
+ .Input("resources: Nresources * resource")
+ .Attr("Nresources: int >= 0")
+ .Output("key: string")
+ .Output("compilation_successful: bool")
+ .Attr("function: func")
+ // The compilation cache is stateful.
+ .SetIsStateful()
+ .Doc(R"(XLA Compile Op. For use by the XLA JIT only.
+
+Compiles a TensorFlow function into an XLA LocalExecutable and returns a key
+that _XlaRun can use to look up the LocalExecutable and execute it.
+
+key: A key that can be used to look up the local executable compiled by the
+ node and associated metadata.
+
+compilation_successful: True iff the compilation was successful. Always true
+for now.
+)");
+
+// Second half of a split XlaLaunch: executes the computation identified by the
+// key produced by _XlaCompile.
+REGISTER_OP("_XlaRun")
+ // TODO(sanjoy): We don't need constants and Tconstants and they should be
+ // removed.
+ .Input("constants: Tconstants")
+ .Attr("Tconstants: list(type) >= 0")
+ .Input("args: Targs")
+ .Attr("Targs: list(type) >= 0")
+ .Output("results: Tresults")
+ .Attr("Tresults: list(type) >= 0")
+ // `key` must remain the last input: the kernel reads it as input
+ // num_inputs() - 1.
+ .Input("key: string")
+ // XLA random-number generation ops are stateful.
+ // TODO(phawkins): create stateful and non-stateful variants of _XlaRun.
+ .SetIsStateful()
+ .Doc(R"(XLA Run Op. For use by the XLA JIT only.
+
+Executes a TensorFlow function previously compiled into a LocalExecutable by an
+_XlaCompile op.
+)");
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 1afc305..003c1d8 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -16,7 +16,7 @@
// Registers the XLA_CPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "Host" (CPU) backend.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device.h"
@@ -65,11 +65,14 @@
// Kernel registrations
-constexpr std::array<DataType, 9> kAllXlaCpuTypes = {
- {DT_UINT8, DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 12> kAllXlaCpuTypes = {
+ {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+ DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 51797de..32fce2b 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -434,6 +434,16 @@
return status;
}
+// Updates the value reported by RequiresSyncOnCompletion(), under mu_.
+void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) {
+  mutex_lock guard(mu_);
+  sync_on_completion_ = sync_on_completion;
+}
+
+// Reports the flag last stored by SetRequiresSyncOnCompletion() (false until
+// set). mu_ is mutable, so it can be locked from this const method.
+bool XlaDevice::RequiresSyncOnCompletion() const {
+  mutex_lock guard(mu_);
+  const bool sync_on_completion = sync_on_completion_;
+  return sync_on_completion;
+}
+
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device) {
// Any op assigned to the device that isn't rewritten by the graph rewriter
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 92891ff..0f06b3f 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -151,6 +151,12 @@
// information for GPU and TPU devices.
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
+ // Instructs this XlaDevice to return 'sync_on_completion' for
+ // RequiresSyncOnCompletion().
+ void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
+
+ bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
+
private:
xla::LocalClient* client() const;
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
@@ -165,7 +171,7 @@
static Status GetMetadataFromDevice(DeviceBase* device,
const XlaDevice::Metadata** metadata);
- mutex mu_;
+ mutable mutex mu_;
// The metadata of this XlaDevice.
const Metadata xla_metadata_;
// Which hardware device in the client's platform this XlaDevice controls.
@@ -207,6 +213,10 @@
// Thread pool used for running closures
std::unique_ptr<thread::ThreadPool> thread_pool_;
+
+ // True if the device requires XlaDevice::Sync to be called on completion
+ // regardless of status.
+ bool sync_on_completion_ GUARDED_BY(mu_) = false;
};
// Builds OpKernel registrations on 'device' for the JIT operators
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 49c8582..6392439 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -65,6 +65,17 @@
.HostMemory("resources"), \
KERNEL);
+// Registrations for the two halves of a split XlaLaunch on an XLA device.
+// XlaCompileOp writes both of its scalar outputs from host memory, and
+// XlaRunOp reads the compilation key on the host, so those arguments are
+// pinned to host memory. TYPES is currently unused by these macros.
+#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES)          \
+  REGISTER_KERNEL_BUILDER(Name("_XlaCompile")                       \
+                              .Device(DEVICE)                       \
+                              .HostMemory("constants")              \
+                              .HostMemory("key")                    \
+                              .HostMemory("compilation_successful") \
+                              .HostMemory("resources"),             \
+                          KERNEL);
+
+#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
+  REGISTER_KERNEL_BUILDER(Name("_XlaRun")              \
+                              .Device(DEVICE)          \
+                              .HostMemory("constants") \
+                              .HostMemory("key"),      \
+                          KERNEL);
+
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 4cf5565..6097955 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -16,7 +16,7 @@
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -74,11 +74,14 @@
// Kernel registrations
-constexpr std::array<DataType, 10> kAllXlaGpuTypes = {
- {DT_UINT8, DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
+constexpr std::array<DataType, 13> kAllXlaGpuTypes = {
+ {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+ DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index 4574559..19e681a 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -15,7 +15,7 @@
// Registers the XLA_INTERPRETER device which exposes the XLA Interpreter.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -72,6 +72,10 @@
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp,
kExecAllTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp,
+ kExecAllTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index affeab4..07a93e9 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -42,7 +42,7 @@
} // anonymous namespace
std::map<int, OptionalTensor> SnapshotResourceVariables(
- OpKernelContext* ctx, const std::vector<int>& variables) {
+ OpKernelContext* ctx, absl::Span<const int> variables) {
std::map<int, OptionalTensor> snapshot;
for (int i : variables) {
Var* variable = nullptr;
@@ -275,6 +275,8 @@
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
if (type == DT_RESOURCE) {
+ TF_RET_CHECK(kernel->outputs[i].input_index >= 0)
+ << "Invalid input for outputs " << i;
ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 7ac275f..fa7a5e5 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -29,6 +29,7 @@
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
class XlaAllocator;
@@ -43,7 +44,7 @@
// resource variable is not initialized, the corresponding OptionalTensor
// will have its `present` field set to false.
std::map<int, OptionalTensor> SnapshotResourceVariables(
- OpKernelContext* ctx, const std::vector<int>& variables);
+ OpKernelContext* ctx, absl::Span<const int> variables);
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 97ed554..3cf74fa 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -978,7 +978,7 @@
name = "gather_test",
size = "medium",
srcs = ["gather_test.py"],
- tags = ["noasan"], # times out, http://b/78599043
+ tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@@ -1198,6 +1198,19 @@
)
tf_xla_py_test(
+ name = "quantized_ops_test",
+ size = "small",
+ srcs = ["quantized_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "xla_ops_test",
size = "medium",
srcs = ["xla_ops_test.py"],
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index 1147933..1d3979b 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -2,6 +2,10 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("//tensorflow/compiler/tests:plugin.bzl", "plugins")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
def all_backends():
b = ["cpu"] + plugins.keys()
@@ -58,14 +62,14 @@
if backend == "cpu":
backend_args += [
"--test_device=XLA_CPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_INT8,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
]
elif backend == "gpu":
backend_args += [
"--test_device=XLA_GPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_INT8,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
]
- backend_tags += ["requires-gpu-sm35"]
+ backend_tags += tf_cuda_tests_tags()
elif backend in plugins:
backend_args += [
"--test_device=" + plugins[backend]["device"],
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index 0af74c2..9390870 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -45,17 +45,21 @@
return any([substr in x for x in labels])
-def XlaLaunchOpCount(labels):
- """Count how many XlaLaunch labels are present."""
- return sum("XlaLaunch(" in x for x in labels)
-
-
class DenseLayerTest(test.TestCase):
+ def countXlaOps(self, labels):
+ """Count how many XlaCompile/XlaRun labels are present."""
+ xla_compile_count = sum("XlaCompile(" in x for x in labels)
+ xla_run_count = sum("XlaRun(" in x for x in labels)
+ self.assertEqual(xla_compile_count, xla_run_count)
+ return xla_run_count
+
+
def testDenseLayerAutoJit(self):
"""Tests dense layer compilation in auto-jit mode.
- Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
+ Dense layer should be compiled into a single XlaCompile/XlaRun op pair in
+ auto-jit mode.
"""
os.environ["TF_XLA_FLAGS"] = (
@@ -77,14 +81,14 @@
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(1, XlaLaunchOpCount(labels))
+ self.assertEqual(1, self.countXlaOps(labels))
self.assertFalse(InLabels(labels, "MatMult"))
def testDenseLayerJitScopeDefinedShape(self):
"""Tests that the dense layer node is properly compiled in jit scope.
Dense layer with static shape input tensor should be compiled into a single
- XlaLaunch op by XLA.
+ XlaCompile/XlaRun op pair by XLA.
"""
with self.cached_session() as sess:
@@ -101,7 +105,7 @@
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(1, XlaLaunchOpCount(labels))
+ self.assertEqual(1, self.countXlaOps(labels))
# No need to check whether ListDiff is compiled or not because ListDiff op
# is not used when input tensor shape is fully defined.
@@ -111,7 +115,8 @@
Dense layer uses shape op to get shape of input tensor if its shape is not
fully defined. XLA does not cluster shape op with other operators. But in
experimental_jit_scope, XLA is forced to compile shape op into its own
- cluster, causing dense layer to be split into TWO XlaLaunch ops.
+ cluster, causing dense layer to be split into TWO XlaCompile/XlaRun op
+ pairs.
"""
with self.cached_session() as sess:
@@ -128,7 +133,7 @@
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(2, XlaLaunchOpCount(labels))
+ self.assertEqual(2, self.countXlaOps(labels))
self.assertFalse(InLabels(labels, "MatMult"))
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py
index 089d95d..a38e1ed 100644
--- a/tensorflow/compiler/tests/gather_test.py
+++ b/tensorflow/compiler/tests/gather_test.py
@@ -51,7 +51,7 @@
indices_tf = constant_op.constant(indices)
gather_t = array_ops.gather(params, indices_tf)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- np_val = params_np[indices]
+ np_val = constant_op.constant(params_np[indices])
self.assertAllEqual(np_val, gather_val)
def testScalar2D(self):
@@ -65,7 +65,8 @@
indices = constant_op.constant(2)
gather_t = array_ops.gather(params, indices, axis=axis)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- expected = np.take(params_np, 2, axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, 2, axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32(self):
@@ -80,7 +81,8 @@
indices = constant_op.constant([0, 1, 0, 2])
gather_t = array_ops.gather(params, indices, axis=axis)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32_Int64Indices(self):
@@ -103,7 +105,8 @@
params: params_np,
indices: indices_np
})
- expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testHigherRank(self):
@@ -119,7 +122,8 @@
tf_indices = constant_op.constant(indices, dtype=dtypes.int32)
gather = array_ops.gather(tf_params, tf_indices, axis=axis)
gather_value = sess.run(gather, feed_dict={tf_params: params})
- gather_np = np.take(params, indices, axis=axis)
+ gather_np = constant_op.constant(
+ np.take(params, indices, axis=axis), dtype)
self.assertAllEqual(gather_np, gather_value)
def testIndicesWithDifferentDimensions(self):
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 6fe5a66..bbe746e 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -605,10 +605,6 @@
class NonMaxSuppressionTest(xla_test.XLATestCase):
def testNMS128From1024(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
with compat.forward_compatibility_horizon(2018, 8, 8):
num_boxes = 1024
boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
@@ -644,10 +640,6 @@
self.assertEqual(indices_tf.size, max_output_size)
def testNMS3From6Boxes(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
with compat.forward_compatibility_horizon(2018, 8, 8):
# Three boxes are selected based on IOU.
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
@@ -693,10 +685,6 @@
# Three boxes are selected based on IOU.
# One is filtered out by score threshold.
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
with compat.forward_compatibility_horizon(2018, 8, 8):
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 0839fb1..de68ff0 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -77,11 +77,11 @@
return any([substr in x for x in labels])
-def MetadataHasXlaLaunch(run_metadata):
- """Returns true if there is a XlaLaunch kernel in run_metadata's timeline."""
+def MetadataHasXlaOp(run_metadata):
+ """Returns true if there are XlaRun kernels in run_metadata's timeline."""
# TODO(phawkins): find a less hacky way to test whether a kernel ran.
- return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch")
+ return InLabels(RunMetadataLabels(run_metadata), "XlaRun")
class JitLaunchTest(test.TestCase):
@@ -90,9 +90,10 @@
# Verifies that the outputs match and that XLA was invoked. 'fn' must take
# the same number of tensors as arguments that are in 'args', and must return
# a tuple of output tensors.
- # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node
- # actually ran. However, it is sometimes possible for XlaLaunch ops to be
- # constant-folded away, so the check is optional.
+ #
+ # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun
+ # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun
+ # ops to be constant-folded away, so the check is optional.
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
placeholders = []
@@ -115,7 +116,7 @@
print("Compiled Result {}".format(compiled))
if require_kernel_launch:
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
direct = sess.run(direct_op, feeds)
print("Direct Result {}".format(direct))
@@ -149,10 +150,10 @@
y = math_ops.add(x, x)
return y, y
- # Exercises compling a function (say, Foo) which calls another
- # function (say, Bar) which is not inlined. When the compiler compiles
- # Foo, it needs to symbolic execute Bar correctly regardless whether
- # Bar is inlined or not.
+ # Exercises compiling a function (say, Foo) which calls another function
+ # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs
+ # to symbolically execute Bar correctly regardless of whether Bar is inlined
+ # or not.
# TODO(b/36139787): Re-enable this test when noinline works again.
# Tests compiled=True and noinline=True.
@@ -259,7 +260,7 @@
# TODO(phawkins): really we would like to test that there were exactly
# two kernel launches. However, we have no reliable way to determine
# that.
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
expected = np.square(np.dot(dx, dw) + db)
self.assertAllClose(expected, output, rtol=1e-1)
@@ -289,7 +290,7 @@
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)
def testIgnoredArguments(self):
@@ -313,7 +314,7 @@
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(28, out)
def testLoops(self):
@@ -331,7 +332,7 @@
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(result, np.float32(95), rtol=1e-1)
def testCond(self):
@@ -356,7 +357,7 @@
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(result, np.float32(6), rtol=1e-1)
def testNestedFunction(self):
@@ -441,14 +442,16 @@
self.assertFalse(InLabels(labels, "Log"))
self.assertTrue(InLabels(labels, "Reciprocal"))
self.assertTrue(InLabels(labels, "Mul"))
- self.assertFalse(InLabels(labels, "XlaLaunch"))
+ self.assertFalse(InLabels(labels, "XlaCompile"))
+ self.assertFalse(InLabels(labels, "XlaRun"))
- # Compile the backprop. One XlaLaunch.
+ # Compile the backprop. One XlaCompile/XlaRun pair.
labels = _Run(compiled=True)
self.assertFalse(InLabels(labels, "Log"))
self.assertFalse(InLabels(labels, "Reciprocal"))
self.assertFalse(InLabels(labels, "Mul"))
- self.assertTrue(InLabels(labels, "XlaLaunch"))
+ self.assertTrue(InLabels(labels, "XlaCompile"))
+ self.assertTrue(InLabels(labels, "XlaRun"))
class ElementWiseFusionTest(test.TestCase):
@@ -482,9 +485,12 @@
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = RunMetadataLabels(run_metadata)
- count = sum("XlaLaunch(" in x for x in labels)
- return output, count
+ xla_compile_count = sum("XlaCompile(" in x for x in labels)
+ xla_run_count = sum("XlaRun(" in x for x in labels)
+ self.assertEqual(xla_compile_count, xla_run_count)
+
+ return output, xla_run_count
def testElementWiseClustering(self):
arg0 = np.random.rand(2, 2).astype(np.float32)
diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py
new file mode 100644
index 0000000..80c3385
--- /dev/null
+++ b/tensorflow/compiler/tests/quantized_ops_test.py
@@ -0,0 +1,48 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for quantized operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class QuantizedOpsTest(xla_test.XLATestCase):
+
+ # Verify that quantized types can be clustered by XLA.
+ def testQuantizedTypeRoundtrip(self):
+ with self.cached_session() as session:
+ for dtype in self.quantized_tf_types:
+ in_values = np.array([1, 2, 3, 4, 5, 6])
+ expected = [[1, 2], [3, 4], [5, 6]]
+ with self.test_scope():
+ p = array_ops.placeholder(dtype=dtypes.int32)
+ x = math_ops.cast(p, dtype)
+ x = array_ops.reshape(x, [3, 2])
+
+ value = session.run(x, {p: in_values})
+ self.assertAllEqual(value, expected)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 41fe42a..36ef6ed 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -69,9 +69,8 @@
def rng(dtype):
return random_ops.random_normal(shape=[2], dtype=dtype)
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- dtype = dtypes.float32
- self._testRngIsNotConstant(rng, dtype)
+ for dtype in self._random_types() & self.float_types:
+ self._testRngIsNotConstant(rng, dtype)
def testRandomUniformIsInRange(self):
for dtype in self._random_types():
@@ -93,13 +92,13 @@
def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype)
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- self._testRngIsNotConstant(rng, dtypes.float32)
+ for dtype in self._random_types() & self.float_types:
+ self._testRngIsNotConstant(rng, dtype)
def testTruncatedNormalIsInRange(self):
count = 10000000
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- for dtype in [dtypes.float32]:
+ # TODO(b/34339814): make this test work with 16 bit float types.
+ for dtype in self._random_types() & {dtypes.float32, dtypes.float64}:
with self.cached_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
@@ -145,9 +144,6 @@
self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3)
def testShuffle1d(self):
- # TODO(b/26783907): this test requires the CPU backend to implement sort.
- if self.device in ["XLA_CPU"]:
- return
with self.cached_session() as sess:
with self.test_scope():
x = math_ops.range(1 << 16)
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 51c04b5..dbf4beb 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -48,10 +48,6 @@
self.assertAllClose(v, result, rtol=1e-3)
def testSort(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
for dtype in supported_types.intersection(self.numeric_types):
x = np.arange(101, dtype=dtype)
@@ -60,10 +56,6 @@
xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
def testTopK(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
for dtype in supported_types.intersection(self.numeric_types):
@@ -89,10 +81,6 @@
expected=[x[indices].astype(dtype), indices])
def testTopK2D(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
for dtype in supported_types.intersection(self.numeric_types):
@@ -122,10 +110,6 @@
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
# Only bfloat16 is implemented.
bfloat16 = dtypes.bfloat16.as_numpy_dtype
if bfloat16 not in self.numeric_types:
@@ -144,10 +128,6 @@
def testTopKInfinities(self):
"""Tests that positive and negative infinity sort correctly."""
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
# Only bfloat16 is implemented.
bfloat16 = dtypes.bfloat16.as_numpy_dtype
if bfloat16 not in self.numeric_types:
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 1bea7d9..f386104 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -34,7 +34,7 @@
"""Test cases for stateless random-number generator operators."""
def _random_types(self):
- return [dtypes.float32]
+ return self.float_types & {dtypes.float32, dtypes.float64}
def testDeterminism(self):
# Stateless values should be equal iff the seeds are equal (roughly)
@@ -124,8 +124,7 @@
self.assertTrue(self._anderson_darling(y) < 2.492)
def testTruncatedNormalIsInRange(self):
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- for dtype in [dtypes.float32]:
+ for dtype in self._random_types():
with self.cached_session() as sess, self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 10000000
@@ -159,7 +158,7 @@
# Department of Scientific Computing website. Florida State University.
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
actual_mean = np.mean(y)
- self.assertAllClose(actual_mean, expected_mean, atol=2e-4)
+ self.assertAllClose(actual_mean, expected_mean, atol=5e-4)
expected_median = mu + probit(
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index 55a9921..98a0770 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -122,8 +122,7 @@
expected=np.array([[2], [5]], dtype=dtype))
def testClipByValue(self):
- # TODO(b/78258593): enable integer types here too.
- for dtype in self.float_types:
+ for dtype in self.numeric_types - self.complex_types:
test_cases = [
(np.array([2, 4, 5], dtype=dtype), dtype(7)), #
(dtype(1), np.array([2, 4, 5], dtype=dtype)), #
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 04ea004..77f6eee 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -158,9 +158,6 @@
def testFloatOps(self):
for dtype in self.float_types:
- # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018.
- if dtype == np.float16 and self.device == "XLA_CPU":
- continue
x = np.arange(-0.90, 0.90, 0.25)
self._assertOpOutputMatchesExpected(
math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype))
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index df5c812..98a4198 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -97,9 +97,16 @@
])
self._numeric_tf_types = set(
self.int_tf_types | self._float_tf_types | self.complex_tf_types)
+ self.quantized_tf_types = set(
+ dtype for dtype in self._all_tf_types if dtype.is_quantized)
- self._all_types = set(
- [dtype.as_numpy_dtype for dtype in self._all_tf_types])
+ # Quantized types don't have a numpy equivalent, include them in
+ # all_tf_types but not in all_types.
+ # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
+ # and remove all_types.
+ self._all_types = set(dtype.as_numpy_dtype
+ for dtype in self._all_tf_types
+ if not dtype.is_quantized)
self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types])
self.signed_int_types = set(dtype.as_numpy_dtype
for dtype in self.int_tf_types
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index d9a0257..7b2bb4a 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -18,6 +18,7 @@
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -132,14 +133,14 @@
// If the 2D kernel would be very large, the 1D kernel can be applied once in
// each dimension due to the symmetry of the kernel along all axis to reduce the
// computational intensity.
-std::vector<float> Make1DKernel(int64 n) {
+xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) {
std::vector<float> kernel(n * 2 - 1);
for (int64 i = 0; i < n; ++i) {
float v = (i + 1.0f) / n;
kernel[i] = v;
kernel[n * 2 - 2 - i] = v;
}
- return kernel;
+ return xla::ConstantR1<float>(builder, kernel);
}
// Kernels with more than 16 spatial elements are considered intense and the
@@ -149,41 +150,26 @@
xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
absl::Span<const int64> kernel_size,
int64 channels) {
- xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
+ auto depthwise_kernel = xla::Broadcast(
+ xla::Zero(builder, xla::F32),
+ {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1});
- auto diag = xla::ConvertElementType(
- xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1,
- 2 * kernel_size[1] - 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
- xla::PrimitiveType::F32);
return xla::Mul(
- xla::Mul(diag,
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
+ xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]),
/*broadcast_dimensions=*/{1}),
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
+ Make1DKernel(builder, kernel_size[0]),
/*broadcast_dimensions=*/{0});
}
xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
absl::Span<const int64> kernel_size,
int64 channels, int64 dim) {
- xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
-
- auto diag = xla::ConvertElementType(
- xla::Eq(
- xla::Broadcast(channels_iota,
- {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
- dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
- xla::PrimitiveType::F32);
- if (dim == 1) {
- return xla::Mul(
- diag, xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
- /*broadcast_dimensions=*/{1});
- }
- return xla::Mul(diag,
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
- /*broadcast_dimensions=*/{0});
+ auto depthwise_kernel =
+ xla::Broadcast(xla::Zero(builder, xla::F32),
+ {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
+ dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1});
+ return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]),
+ /*broadcast_dimensions=*/{dim});
}
xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
@@ -206,8 +192,8 @@
xla::ConvolutionDimensionNumbers dimension_numbers;
dimension_numbers.set_input_batch_dimension(0);
dimension_numbers.set_output_batch_dimension(0);
- dimension_numbers.set_input_feature_dimension(3);
- dimension_numbers.set_output_feature_dimension(3);
+ dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
for (int i = 0; i < num_spatial_dims; ++i) {
dimension_numbers.add_input_spatial_dimensions(1 + i);
dimension_numbers.add_output_spatial_dimensions(1 + i);
@@ -285,7 +271,8 @@
{{dims.kernel_size[0] - 1, upper_padding[0]},
{dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/dims.kernel_size,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -294,7 +281,8 @@
/*padding=*/
{{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
/*lhs_dilation=*/{dims.kernel_size[0], 1},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
xla::XlaOp kernel1 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1);
output = xla::ConvGeneralDilated(
@@ -302,7 +290,8 @@
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/{1, dims.kernel_size[1]},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
}
// Add broadcasts to handle expanding from a size == 1 dimension to a
@@ -331,15 +320,15 @@
xla::ConvolutionDimensionNumbers dimension_numbers;
dimension_numbers.set_input_batch_dimension(0);
dimension_numbers.set_output_batch_dimension(0);
- dimension_numbers.set_input_feature_dimension(3);
- dimension_numbers.set_output_feature_dimension(3);
+ dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
for (int i = 0; i < num_spatial_dims; ++i) {
- dimension_numbers.add_input_spatial_dimensions(1 + i);
- dimension_numbers.add_output_spatial_dimensions(1 + i);
+ dimension_numbers.add_input_spatial_dimensions(i + 1);
+ dimension_numbers.add_output_spatial_dimensions(i + 1);
dimension_numbers.add_kernel_spatial_dimensions(i);
}
- dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
- dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
xla::XlaOp output;
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
xla::XlaOp kernel =
@@ -362,7 +351,8 @@
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
{dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/dims.stride,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -388,14 +378,16 @@
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
/*lhs_dilation=*/{dims.stride[0], 1},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
output = xla::ConvGeneralDilated(
output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/{1, dims.stride[1]},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
}
// If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
index 23d04d4..bc44301 100644
--- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
@@ -20,21 +20,6 @@
namespace tensorflow {
bool CpuOpFilter(KernelDef* kdef) {
- // TODO(b/34339814): implement inverse erf for double types and remove this
- // workaround.
- if (kdef->op() == "RandomStandardNormal") {
- kdef->clear_constraint();
- // Change the type constraint to permit only DTD_FLOAT.
- KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
- attr_constraint->set_name("dtype");
- attr_constraint->mutable_allowed_values()->mutable_list()->add_type(
- DT_FLOAT);
- return true;
- }
- // TODO(b/26783907): The CPU backend currently does not implement sort.
- if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") {
- return false;
- }
if (kdef->op() == "Const") {
AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
}
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index a4b6248..4b2c2ba 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -51,13 +51,14 @@
{DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}};
-constexpr std::array<DataType, 11> kCpuAllTypes = {
- {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
- DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 14> kCpuAllTypes = {
+ {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
+ DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
-constexpr std::array<DataType, 12> kGpuAllTypes = {
- {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
- DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
+constexpr std::array<DataType, 15> kGpuAllTypes = {
+ {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
+ DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
+ DT_BFLOAT16}};
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index ef70c1f..cc7390c 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -245,6 +245,7 @@
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 9da5dc0..cd5fd33 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -469,9 +469,11 @@
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
- lhs_dilation, rhs_dilation, dimension_numbers);
+ lhs_dilation, rhs_dilation, dimension_numbers,
+ feature_group_count);
}
LocalOp LocalComputationBuilder::ConvertElementType(
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 1d5dfe5..2166bb6 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -248,7 +248,8 @@
absl::Span<const std::pair<int64, int64> > padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
LocalOp ConvertElementType(const LocalOp& operand,
PrimitiveType new_element_type);
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index fa4366f..bb303c5 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -1109,7 +1109,7 @@
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
return self._client.DotGeneral(lhs, rhs, dimension_numbers)
- def Conv(self, lhs, rhs, window_strides, padding):
+ def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1):
"""Enqueues a Conv operation onto the computation.
Args:
@@ -1117,6 +1117,7 @@
rhs: LocalOp for the rank N+2 array of kernel weights.
window_strides: length-N array-like of integer kernel strides.
padding: PaddingType representing either 'SAME' or 'VALID' padding.
+ feature_group_count: number of feature groups for grouped convolution.
Returns: a LocalOp representing the Conv operation.
"""
@@ -1125,10 +1126,11 @@
self.GetShape(rhs).dimensions()[2:], window_strides)
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (),
- (), dimension_numbers)
+ (), dimension_numbers,
+ feature_group_count)
def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
- lhs_dilation, rhs_dilation):
+ lhs_dilation, rhs_dilation, feature_group_count=1):
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
Args:
@@ -1138,6 +1140,7 @@
padding: length-N array-like of pairs of integers of (low, high) padding.
lhs_dilation: length-N array-like of dilation factors.
rhs_dilation: length-N array-like of dilation factors.
+ feature_group_count: number of feature groups for grouped convolution.
Returns:
A ComputationdataHandle representing the added ConvWithGeneralPadding op.
@@ -1145,7 +1148,8 @@
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
- dimension_numbers)
+ dimension_numbers,
+ feature_group_count)
def _GetConvDimensionNumbers(self, num_spatial_dims):
"""Create ConvolutionDimensionNumbers proto for convolutions."""
@@ -1163,7 +1167,8 @@
return dimension_numbers
def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation,
- rhs_dilation, dimension_numbers):
+ rhs_dilation, dimension_numbers,
+ feature_group_count=1):
"""Enqueues a ConvGeneralDilated operation onto the computation.
Args:
@@ -1190,6 +1195,7 @@
labels appear in the rhs_spec string, so that window_strides[0] is
matched with the dimension corresponding to the first character
appearing in rhs_spec that is not 'I' or 'O'.
+ feature_group_count: number of feature groups for grouped convolution.
Returns: a LocalOp representing the ConvGenralDilated operation.
"""
@@ -1215,7 +1221,8 @@
key=lambda i: rhs_spec.index(out_spec[i])))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
- dimension_numbers)
+ dimension_numbers,
+ feature_group_count)
def Sort(self, operand, dimension=-1):
"""Enqueues a sort operation onto the computation."""
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index fd98e19..82103f0 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -661,6 +661,30 @@
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
+ def testConvGeneralDilatedGroupedConvolutionF32(self):
+ c = self._NewComputation()
+ a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
+ lhs = a(1, 2, 2, 3)
+ rhs = a(2, 1, 1, 2) * 10
+ strides = [1, 1]
+ pads = [(1, 0), (0, 1)]
+ lhs_dilation = (2, 1)
+ rhs_dilation = (1, 1)
+ dimension_numbers = ("NCHW", "OIHW", "NCHW")
+ feature_group_count = 2
+ c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
+ strides, pads, lhs_dilation, rhs_dilation,
+ dimension_numbers, feature_group_count)
+ result = np.array([[[[0., 0., 0.],
+ [10., 20., 0.],
+ [0., 0., 0.],
+ [40., 50., 0.]],
+ [[0., 0., 0.],
+ [330., 380., 160.],
+ [0., 0., 0.],
+ [480., 530., 220.]]]])
+ self._ExecuteAndCompareClose(c, expected=result)
+
def testBooleanNot(self):
c = self._NewComputation()
arr = NumpyArrayBool([True, False, True])
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 68bf56c..2bc50c7 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1169,6 +1169,7 @@
":hlo",
":hlo_matchers",
":hlo_module_group",
+ ":hlo_module_group_metadata",
":hlo_parser",
":hlo_proto",
"//tensorflow/compiler/xla:test",
@@ -2560,6 +2561,7 @@
],
deps = [
":hlo",
+ ":hlo_module_group",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
@@ -2591,6 +2593,26 @@
],
)
+tf_cc_test(
+ name = "hlo_pass_pipeline_test",
+ srcs = ["hlo_pass_pipeline_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_parser",
+ ":hlo_pass_pipeline",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
cc_library(
name = "hlo_cse",
srcs = ["hlo_cse.cc"],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 4ef1dff..75dae7a 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -754,11 +754,12 @@
};
auto reshape_if_necessary = [&](HloInstruction* hlo) {
+ hlo = as_type(hlo, dot->shape().element_type());
if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
hlo = computation_->AddInstruction(
HloInstruction::CreateReshape(dot->shape(), hlo));
}
- return as_type(hlo, dot->shape().element_type());
+ return hlo;
};
auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) {
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index b864c37..9f8d0ee 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -24,7 +24,7 @@
namespace xla {
// A pass which performs algebraic simplifications.
-class AlgebraicSimplifier : public HloPassInterface {
+class AlgebraicSimplifier : public HloModulePass {
public:
// Given shapes 'from_shape' and 'to_shape', determines if it is valid to
// bitcast from 'from_shape' to 'to_shape' after considering platform
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 3fc1ba2..2047f89 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -3233,17 +3233,18 @@
class DotStrengthReductionTest
: public AlgebraicSimplifierTest,
public ::testing::WithParamInterface<
- ::testing::tuple<int, int, int, bool, bool>> {};
+ ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {};
TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
int m, k, n;
bool transpose_lhs, transpose_rhs;
- std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam();
+ PrimitiveType element_type;
+ std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam();
- Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
- Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
- Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m});
- Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
- Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k});
+ Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n});
+ Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k});
+ Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m});
+ Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n});
+ Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k});
HloComputation::Builder builder(TestName());
auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -3285,7 +3286,7 @@
DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
::testing::Values(1, 2), ::testing::Bool(),
- ::testing::Bool()));
+ ::testing::Bool(), ::testing::Values(F32, BF16)));
struct DotOfConcatTestSpec {
int64 m;
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h
index 79d37f0..5b625bf 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.h
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h
@@ -25,7 +25,7 @@
// Normally these would live in the algebraic simplifier, but we want to run
// this to fixpoint (this pass reaches fixed point in one execution) before we
// run the DotDecomposer.
-class BatchDotSimplification : public HloPassInterface {
+class BatchDotSimplification : public HloModulePass {
public:
StatusOr<bool> Run(HloModule* module) override;
absl::string_view name() const override;
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h
index 76e3217..147f3ae 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.h
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.h
@@ -26,7 +26,7 @@
// A pass which rewrites batch norm operations into more operations. Breaking a
// big operation into smaller operations helps leverage our generic fusion
// logic.
-class BatchNormExpander : public HloPassInterface {
+class BatchNormExpander : public HloModulePass {
public:
// When use_fusion is set, a multi-output fusion node is created.
BatchNormExpander(bool rewrite_training_op = false,
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
index 5dcd31b..cb3d12f 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
@@ -31,7 +31,7 @@
// optimization pipeline followed by a DCE pass. If other passes are needed
// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
// changed made by this pass.
-class BFloat16ConversionFolding : public HloPassInterface {
+class BFloat16ConversionFolding : public HloModulePass {
public:
explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h
index 30b6346..f48e925 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.h
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h
@@ -25,7 +25,7 @@
// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not
// support BF16 input/output or mixed precision, according to the passed-in
// backend-specific BF16 support rules.
-class BFloat16Normalization : public HloPassInterface {
+class BFloat16Normalization : public HloModulePass {
public:
explicit BFloat16Normalization(const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}
@@ -48,7 +48,7 @@
// use mixed precision; it removes mixed precision even if the backend supports
// it. This pass is used to make the HLO module valid for other HLO passes which
// do not support mixed precision.
-class BFloat16MixedPrecisionRemoval : public HloPassInterface {
+class BFloat16MixedPrecisionRemoval : public HloModulePass {
public:
BFloat16MixedPrecisionRemoval() {}
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
index 1ee6497..6a62439 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.h
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -58,7 +58,7 @@
// BFloat16ConversionFolding. If other passes are needed after this pass, run
// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this
// pass.
-class BFloat16Propagation : public HloPassInterface {
+class BFloat16Propagation : public HloModulePass {
public:
explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
index c5cd88b..08c4aff 100644
--- a/tensorflow/compiler/xla/service/call_inliner.h
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -25,7 +25,7 @@
// For every kCall operation in the main computation, we inline the body of the
// called function, and proceed recursively.
-class CallInliner : public HloPassInterface {
+class CallInliner : public HloModulePass {
public:
using InlinedInstructionMap =
std::unordered_map<HloInstruction*, HloInstruction*>;
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h
index 3de50cb..2223ad6 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.h
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.h
@@ -25,7 +25,7 @@
// HLO pass that removes kConditional with a constant predicate, replacing them
// with their true or false computation as appropriate.
-class ConditionalSimplifier : public HloPassInterface {
+class ConditionalSimplifier : public HloModulePass {
public:
absl::string_view name() const override { return "simplify-conditional"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
index 4988947..ce0138e 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
@@ -25,7 +25,7 @@
// A pass which rewrites convolutions with feature_group_count > 1 into
// convolutions with feature_group_count = 1.
-class ConvolutionFeatureGroupConverter : public HloPassInterface {
+class ConvolutionFeatureGroupConverter : public HloModulePass {
public:
ConvolutionFeatureGroupConverter() {}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index d308f6b..c097089 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -43,7 +43,7 @@
// (3) The buffer set of the root instruction of the entry computation must be
// unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and
// InstructionAliasSet::IsDistinct return true.
-class CopyInsertion : public HloPassInterface {
+class CopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 8cc522a..bf62798 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -180,6 +180,7 @@
":runtime_conv2d_mkl",
":runtime_fft",
":runtime_fork_join",
+ ":runtime_key_value_sort",
":runtime_matmul",
":runtime_matmul_mkl",
":runtime_single_threaded_conv2d",
@@ -624,6 +625,18 @@
)
cc_library(
+ name = "runtime_key_value_sort",
+ srcs = ["runtime_key_value_sort.cc"],
+ hdrs = ["runtime_key_value_sort.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework_lite",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
name = "runtime_fork_join",
srcs = ["runtime_fork_join.cc"],
hdrs = ["runtime_fork_join.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
index 59437e8..becee3f 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
@@ -31,7 +31,7 @@
// called canonical convolutions). This pass expands non-canonical convolutions
// into reshapes and canonical convolutions, so that these non-canonical
// convolutions can run faster.
-class ConvCanonicalization : public HloPassInterface {
+class ConvCanonicalization : public HloModulePass {
public:
explicit ConvCanonicalization(
const TargetMachineFeatures* target_machine_features)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
index d49f7d7..076235f 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
@@ -30,7 +30,7 @@
//
// TODO(b/62548313): Remove this when buffer assignment is smarter
// (module-scoped).
-class CpuCopyInsertion : public HloPassInterface {
+class CpuCopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
index 6af724b..a39a9d4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
@@ -23,7 +23,7 @@
// This pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the CPU backend and cannot be removed via HLO
// transformations (eg, sparse layouts).
-class CpuHloSupportChecker : public HloPassInterface {
+class CpuHloSupportChecker : public HloModulePass {
public:
CpuHloSupportChecker() = default;
~CpuHloSupportChecker() override = default;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 8a44c38..7e15909 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -74,6 +74,30 @@
"__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
"__xla_cpu_runtime_ParallelForkJoin";
+extern const char* const kKeyValueSortPREDSymbolName =
+ "__xla_cpu_runtime_KeyValueSortPRED";
+extern const char* const kKeyValueSortS8SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS8";
+extern const char* const kKeyValueSortU8SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU8";
+extern const char* const kKeyValueSortS16SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS16";
+extern const char* const kKeyValueSortU16SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU16";
+extern const char* const kKeyValueSortF16SymbolName =
+ "__xla_cpu_runtime_KeyValueSortF16";
+extern const char* const kKeyValueSortS32SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS32";
+extern const char* const kKeyValueSortU32SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU32";
+extern const char* const kKeyValueSortF32SymbolName =
+ "__xla_cpu_runtime_KeyValueSortF32";
+extern const char* const kKeyValueSortS64SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS64";
+extern const char* const kKeyValueSortU64SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU64";
+extern const char* const kKeyValueSortF64SymbolName =
+ "__xla_cpu_runtime_KeyValueSortF64";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
} // namespace runtime
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index aa0e967..e6345e0 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -63,6 +63,18 @@
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
extern const char* const kParallelForkJoinSymbolName;
+extern const char* const kKeyValueSortPREDSymbolName;
+extern const char* const kKeyValueSortS8SymbolName;
+extern const char* const kKeyValueSortU8SymbolName;
+extern const char* const kKeyValueSortS16SymbolName;
+extern const char* const kKeyValueSortU16SymbolName;
+extern const char* const kKeyValueSortF16SymbolName;
+extern const char* const kKeyValueSortS32SymbolName;
+extern const char* const kKeyValueSortU32SymbolName;
+extern const char* const kKeyValueSortF32SymbolName;
+extern const char* const kKeyValueSortS64SymbolName;
+extern const char* const kKeyValueSortU64SymbolName;
+extern const char* const kKeyValueSortF64SymbolName;
// All symbol names for XLA CPU runtime functions need to start with this
// prefix.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index df8c2a6..c32f253 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -495,8 +495,150 @@
}
Status IrEmitter::HandleSort(HloInstruction* sort) {
- // TODO(b/26783907): Implement sort on CPU.
- return Unimplemented("Sort is not implemented on CPU.");
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
+ auto keys = sort->operand(0);
+ auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+ ShapeIndex keys_shape_index({});
+ ShapeIndex values_shape_index({});
+ if (values != nullptr) {
+ keys_shape_index = ShapeIndex({0});
+ values_shape_index = ShapeIndex({1});
+ }
+ auto keys_destination = GetAllocationSlice(*sort, keys_shape_index);
+ auto keys_destination_address =
+ EmitBufferPointer(keys_destination, keys->shape());
+ auto values_destination = GetAllocationSlice(*sort, values_shape_index);
+ llvm::Value* values_destination_address = nullptr;
+
+ // The sort is implemented in-place, therefore we first copy the operand
+ // buffer to the output buffer if they are not the same.
+ if (keys_destination != GetAllocationSlice(*keys)) {
+ int64 primitive_type_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type());
+ auto source_buffer = GetEmittedValueFor(keys);
+ int64 keys_size = ByteSizeOf(keys->shape());
+ MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size,
+ source_buffer,
+ /*SrcAlign=*/primitive_type_size, keys_size);
+ }
+ if (values != nullptr) {
+ values_destination_address =
+ EmitBufferPointer(values_destination, values->shape());
+ if (values_destination != GetAllocationSlice(*values)) {
+ int64 primitive_type_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type());
+ auto source_buffer = GetEmittedValueFor(values);
+ int64 values_size = ByteSizeOf(values->shape());
+ MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size,
+ source_buffer,
+ /*SrcAlign=*/primitive_type_size, values_size);
+ }
+ }
+
+ // Normalize the shape and the dimension to sort.
+ Shape normalized_keys_shape =
+ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
+ keys->shape());
+ int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
+ keys->shape().layout())[sort->dimensions(0)];
+
+ int64 sort_dimension_elements =
+ normalized_keys_shape.dimensions(physical_dimension_to_sort);
+ int64 higher_dimensions = 1;
+ for (int64 i = 0; i < physical_dimension_to_sort; ++i) {
+ higher_dimensions *= normalized_keys_shape.dimensions(i);
+ }
+ int64 lower_dimensions = 1;
+ for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1;
+ i > physical_dimension_to_sort; --i) {
+ lower_dimensions *= normalized_keys_shape.dimensions(i);
+ }
+
+ PrimitiveType keys_type = keys->shape().element_type();
+ const char* fn_name = nullptr;
+ llvm::Type* keys_native_type = nullptr;
+ switch (keys_type) {
+ case PRED:
+ fn_name = runtime::kKeyValueSortPREDSymbolName;
+ keys_native_type = b_.getInt8PtrTy();
+ break;
+ case S8:
+ fn_name = runtime::kKeyValueSortS8SymbolName;
+ keys_native_type = b_.getInt8PtrTy();
+ break;
+ case U8:
+ fn_name = runtime::kKeyValueSortU8SymbolName;
+ keys_native_type = b_.getInt8PtrTy();
+ break;
+ case S16:
+ fn_name = runtime::kKeyValueSortS16SymbolName;
+ keys_native_type = b_.getInt16Ty()->getPointerTo();
+ break;
+ case U16:
+ fn_name = runtime::kKeyValueSortU16SymbolName;
+ keys_native_type = b_.getInt16Ty()->getPointerTo();
+ break;
+ case F16:
+ fn_name = runtime::kKeyValueSortF16SymbolName;
+ keys_native_type = b_.getHalfTy()->getPointerTo();
+ break;
+ case S32:
+ fn_name = runtime::kKeyValueSortS32SymbolName;
+ keys_native_type = b_.getInt32Ty()->getPointerTo();
+ break;
+ case U32:
+ fn_name = runtime::kKeyValueSortU32SymbolName;
+ keys_native_type = b_.getInt32Ty()->getPointerTo();
+ break;
+ case F32:
+ fn_name = runtime::kKeyValueSortF32SymbolName;
+ keys_native_type = b_.getFloatTy()->getPointerTo();
+ break;
+ case S64:
+ fn_name = runtime::kKeyValueSortS64SymbolName;
+ keys_native_type = b_.getInt64Ty()->getPointerTo();
+ break;
+ case U64:
+ fn_name = runtime::kKeyValueSortU64SymbolName;
+ keys_native_type = b_.getInt64Ty()->getPointerTo();
+ break;
+ case F64:
+ fn_name = runtime::kKeyValueSortF64SymbolName;
+ keys_native_type = b_.getDoubleTy()->getPointerTo();
+ break;
+ default:
+ return Unimplemented(
+ "Element type %s not supported in the Sort op on CPU.",
+ PrimitiveType_Name(keys_type));
+ }
+
+ llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
+ b_.getVoidTy(),
+ {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
+ b_.getInt8PtrTy(), b_.getInt32Ty()},
+ /*isVarArg=*/false);
+ auto* key_value_sort_func = llvm::cast<llvm::Function>(
+ module_->getOrInsertFunction(fn_name, key_value_sort_type));
+ key_value_sort_func->setCallingConv(llvm::CallingConv::C);
+ key_value_sort_func->setDoesNotThrow();
+ key_value_sort_func->setOnlyAccessesArgMemory();
+ Call(key_value_sort_func,
+ {PointerCast(keys_destination_address, keys_native_type),
+ b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
+ b_.getInt64(lower_dimensions),
+ values != nullptr
+ ? PointerCast(values_destination_address, b_.getInt8PtrTy())
+ : llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+ b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType(
+ values->shape().element_type())
+ : 0)});
+
+ if (values != nullptr) {
+ llvm_ir::EmitTuple(GetIrArrayFor(sort),
+ {keys_destination_address, values_destination_address},
+ &b_, module_);
+ }
+ return Status::OK();
}
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 3df9946..daafef4 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -163,6 +163,12 @@
Status Preprocess(HloInstruction* hlo) override;
Status Postprocess(HloInstruction* hlo) override;
+ // A convenient helper for calling BufferAssignment::GetUniqueSlice.
+ BufferAllocation::Slice GetAllocationSlice(
+ const HloInstruction& hlo, const ShapeIndex& index = {}) const {
+ return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie();
+ }
+
private:
// Private helper to initialize an IR function for the computation.
void InitializeIrFunction(const string& function_name);
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index b4c0c09..ede7f43 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -142,6 +142,7 @@
opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
+ opcode == HloOpcode::kSort ||
(opcode == HloOpcode::kConvolution &&
PotentiallyImplementedAsEigenConvolution(*instruction,
target_machine_features_)) ||
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
index a99cd99..3822d53 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -60,7 +60,7 @@
// own embedded computation, which is compiled as a parallel compute function,
// and which is invoked from a kCall instruction that is lowered in codegen to
// a runtime parallel fork/join call.
-class ParallelTaskAssigner : public HloPassInterface {
+class ParallelTaskAssigner : public HloModulePass {
public:
// 'max_parallelism': the maximum parallel task count per instruction.
// 'shape_size': shape size function used by HloCostAnalysis during parallel
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
new file mode 100644
index 0000000..e0e7deb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
@@ -0,0 +1,253 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace {
+using tensorflow::int16;
+using tensorflow::int32;
+using tensorflow::int64;
+using tensorflow::int8;
+using tensorflow::uint16;
+using tensorflow::uint32;
+using tensorflow::uint64;
+using tensorflow::uint8;
+
+// Sorts 'num_elements' (key, original index) pairs in place. std::pair's
+// lexicographic ordering compares the keys first and breaks ties with the
+// original index, so elements with equal keys keep their relative order.
+template <typename KeyType>
+void KeyValueSort(std::pair<KeyType, int64>* row_to_sort, int64 num_elements) {
+  std::sort(row_to_sort, row_to_sort + num_elements);
+}
+
+// Strict total order for floating point keys:
+//   -NaN < -inf < ... < -0.0 < 0.0 < ... < inf < NaN.
+// Keys that compare equal, including two NaNs with the same sign, are ordered
+// by their original index, which makes the sort stable.
+template <typename KeyType>
+bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) {
+  const bool lhs_is_negative = std::signbit(lhs);
+  const bool rhs_is_negative = std::signbit(rhs);
+  // If the signs are different, the negative one comes first.
+  if (lhs_is_negative != rhs_is_negative) {
+    return lhs_is_negative;
+  }
+  const bool lhs_nan = std::isnan(lhs);
+  const bool rhs_nan = std::isnan(rhs);
+  // Exactly one number is NaN (both have the same sign here): -NaN sorts
+  // before every negative number and NaN after every positive number.
+  if (lhs_nan != rhs_nan) {
+    if (lhs_nan) {
+      return lhs_is_negative;
+    }
+    return !rhs_is_negative;
+  }
+  // Compare values only if neither is NaN; two NaNs (for which 'lhs != rhs'
+  // holds but 'lhs < rhs' does not) fall through to the index tie-break.
+  if (!lhs_nan && lhs != rhs) {
+    return lhs < rhs;
+  }
+  return lhs_index < rhs_index;
+}
+
+template <>
+void KeyValueSort(std::pair<double, int64>* row_to_sort, int64 num_elements) {
+  std::sort(row_to_sort, row_to_sort + num_elements,
+            [](const std::pair<double, int64>& lhs,
+               const std::pair<double, int64>& rhs) -> bool {
+              return LessThan(lhs.first, lhs.second, rhs.first, rhs.second);
+            });
+}
+
+template <>
+void KeyValueSort(std::pair<float, int64>* row_to_sort, int64 num_elements) {
+  std::sort(row_to_sort, row_to_sort + num_elements,
+            [](const std::pair<float, int64>& lhs,
+               const std::pair<float, int64>& rhs) -> bool {
+              return LessThan(lhs.first, lhs.second, rhs.first, rhs.second);
+            });
+}
+
+template <>
+void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort,
+                  int64 num_elements) {
+  // Compare as float: the half -> float conversion is exact.
+  std::sort(row_to_sort, row_to_sort + num_elements,
+            [](const std::pair<Eigen::half, int64>& lhs,
+               const std::pair<Eigen::half, int64>& rhs) -> bool {
+              return LessThan(
+                  Eigen::half_impl::half_to_float(lhs.first), lhs.second,
+                  Eigen::half_impl::half_to_float(rhs.first), rhs.second);
+            });
+}
+
+// Sorts 'keys', viewed as a row-major [a, b, c] array, along dimension 'b'.
+// If 'values' is non-null it has the same [a, b, c] layout, with elements of
+// 'values_primitive_type_size_in_bytes' bytes each, and it is permuted in the
+// same way as 'keys'.
+template <typename KeyType>
+void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
+                      int32 values_primitive_type_size_in_bytes) {
+  // High-level idea of the iteration/sorting logic:
+  // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the
+  // dimension to sort, c is the product of the more minor dimensions (set to 1
+  // if b is the most minor dimension), and a is the product of the more major
+  // dimensions (set to 1 if b is the most major dimension). There are a * c
+  // many rows that we need to sort. We iterate through these, calculate a
+  // 'base_offset' value which points to the first element in that row, and add
+  // i * c for accessing the 'i'-th element in that row.
+
+  const int64 sort_dimension_elements = b;
+  const int64 num_iteration_elements = a * c;
+  const int64 sort_dimension_offset = c;
+  const int64 value_size = values_primitive_type_size_in_bytes;
+
+  std::vector<std::pair<KeyType, int64>> row_to_sort(sort_dimension_elements);
+  // Scratch space for one row of values in sorted order. One contiguous buffer
+  // replaces a std::string per element; it stays empty when there are no
+  // values to reorder.
+  std::vector<char> reordered_values(
+      values == nullptr ? 0 : sort_dimension_elements * value_size);
+  for (int64 index = 0; index < num_iteration_elements; ++index) {
+    // 'index' can be split into two values which index into the 'c' dimension
+    // and the 'a' dimension, respectively. 'index' % 'c' is the index into the
+    // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When
+    // calculating the base offset, we need to multiply the index into the 'a'
+    // dimension with 'b' * 'c'.
+    // 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'.
+    const int64 base_offset =
+        index % sort_dimension_offset +
+        (index - index % sort_dimension_offset) * sort_dimension_elements;
+    // TODO(b/26783907): We could define a custom iterator class that references
+    // both arrays. Then we could avoid the intermediate copy. However this
+    // would become more complicated, and it is not clear if the benefit is high
+    // enough.
+    for (int64 i = 0; i < sort_dimension_elements; ++i) {
+      row_to_sort[i] =
+          std::make_pair(keys[base_offset + i * sort_dimension_offset], i);
+    }
+    KeyValueSort(row_to_sort.data(), sort_dimension_elements);
+    for (int64 i = 0; i < sort_dimension_elements; ++i) {
+      keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first;
+    }
+    if (values == nullptr) {
+      continue;
+    }
+
+    // Gather the values in sorted order into the scratch buffer first, then
+    // scatter them back; no value may be overwritten before it has been read.
+    for (int64 i = 0; i < sort_dimension_elements; ++i) {
+      const int64 memory_index =
+          (base_offset + row_to_sort[i].second * sort_dimension_offset) *
+          value_size;
+      std::memcpy(reordered_values.data() + i * value_size,
+                  values + memory_index, value_size);
+    }
+    for (int64 i = 0; i < sort_dimension_elements; ++i) {
+      const int64 memory_index =
+          (base_offset + i * sort_dimension_offset) * value_size;
+      std::memcpy(values + memory_index,
+                  reordered_values.data() + i * value_size, value_size);
+    }
+  }
+}
+}  // namespace
+
+// C entry points, one per key primitive type, that are registered with the
+// CPU JIT and called from generated code. See runtime_key_value_sort.h.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED(
+    bool* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8(
+    int8* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8(
+    uint8* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16(
+    int16* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16(
+    uint16* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16(
+    Eigen::half* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32(
+    int32* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32(
+    uint32* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32(
+    float* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64(
+    int64* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64(
+    uint64* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64(
+    double* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
new file mode 100644
index 0000000..28e35e8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+// 'keys' is viewed as a row-major array of shape [a, b, c], where 'a' is the
+// product of the dimensions more major than the sort dimension and 'c' the
+// product of the more minor ones; each of the a * c rows along 'b' is sorted
+// into ascending order. 'values' may be nullptr; otherwise it has the same
+// [a, b, c] layout, each element being 'values_primitive_type_size_in_bytes'
+// bytes, and every value is moved to wherever its corresponding key moved.
+extern void __xla_cpu_runtime_KeyValueSortPRED(
+    bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS8(
+    tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU8(
+    tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS16(
+    tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU16(
+    tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF16(
+    Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS32(
+    tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU32(
+    tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF32(
+    float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS64(
+    tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU64(
+    tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF64(
+    double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+}
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index bf98064..9ec0c8f 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -35,6 +35,7 @@
#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
@@ -202,6 +203,18 @@
REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64);
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index c55206e..4b129c9 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -180,3 +180,17 @@
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+    name = "cpu_key_value_sort_test",
+    srcs = ["cpu_key_value_sort_test.cc"],
+    deps = [
+        ":cpu_codegen_test",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
new file mode 100644
index 0000000..3934c03
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+
+// Codegen tests checking that XLA:CPU lowers kSort to a call into the
+// __xla_cpu_runtime_KeyValueSort* runtime functions.
+class CpuKeyValueSortTest : public CpuCodegenTest {};
+
+TEST_F(CpuKeyValueSortTest, SortR1) {
+  const string hlo_text = R"(
+HloModule KeyValueSort
+
+ENTRY main {
+  a = f32[10] parameter(0)
+
+  ROOT result = f32[10] sort(f32[10] a), dimensions={0}
+}
+)";
+
+  // The keys are F32, so the F32 entry point of the runtime sort must be used.
+  string filecheck_pattern = R"(
+CHECK: call void @__xla_cpu_runtime_KeyValueSortF32
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(hlo_text));
+
+  CpuAotCompilationOptions options{
+      /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"",
+      /*entry_point_name=*/"entry",
+      /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+  CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
+                                /*match_optimized_ir=*/true);
+}
+
+}  // namespace
+}  // namespace cpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h
index c326beb..aaa41fc 100644
--- a/tensorflow/compiler/xla/service/defuser.h
+++ b/tensorflow/compiler/xla/service/defuser.h
@@ -25,7 +25,7 @@
// A pass which replaces all fusion instructions with the equivalent un-fused
// instructions.
-class Defuser : public HloPassInterface {
+class Defuser : public HloModulePass {
public:
Defuser() {}
~Defuser() override {}
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
index ba2a674..b3549ac 100644
--- a/tensorflow/compiler/xla/service/despecializer.cc
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -24,7 +24,7 @@
namespace {
// Pass which strips control dependencies from all instructions in the module.
-class ControlDepRemover : public HloPassInterface {
+class ControlDepRemover : public HloModulePass {
public:
ControlDepRemover() = default;
absl::string_view name() const override { return "control-dep-remover"; }
diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h
index 7be70ad..46dcc3a 100644
--- a/tensorflow/compiler/xla/service/despecializer.h
+++ b/tensorflow/compiler/xla/service/despecializer.h
@@ -30,7 +30,7 @@
//
// Current despecialization passes are Defuser, ImplicitBroadcastRemover,
// and BFloat16MixedPrecisionRemoval.
-class Despecializer : public HloPassInterface {
+class Despecializer : public HloModulePass {
public:
Despecializer();
absl::string_view name() const override { return "despecializer"; }
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h
index fc38e31..40e7a3b 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.h
+++ b/tensorflow/compiler/xla/service/dot_decomposer.h
@@ -23,7 +23,7 @@
// DotDecomposer is a pass which decomposes batch Dot operations into a
// sequence of smaller (R2) Dot operations.
-class DotDecomposer : public HloPassInterface {
+class DotDecomposer : public HloModulePass {
public:
// Decomposes batch Dot operations when 'decompose_batch_dot' is true.
DotDecomposer(bool decompose_batch_dot = true)
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 4bb1e07..515267e 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -847,29 +847,40 @@
+// Emits IR for the inverse error function of 'x' using the approximations
+// from Giles, "Approximating the erfinv function". F16 inputs are evaluated in
+// F32 and truncated back to F16; F32 and F64 use separate coefficient tables.
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
                                                       llvm::Value* x) {
-  if (prim_type != F32) {
-    // TODO(b/34339814): Implement inverse erf for F64.
+  if (prim_type != F16 && prim_type != F32 && prim_type != F64) {
     return Unimplemented(
         "Inverse erf is only implemented for element "
-        "type F32.");
+        "types F16, F32 and F64.");
   }
-  auto getFloat = [&](const float f) {
-    return llvm::ConstantFP::get(b_->getFloatTy(), f);
+
+  // Upcast half to float.
+  if (prim_type == F16) {
+    x = b_->CreateFPExt(x, b_->getFloatTy());
+  }
+
+  auto get_float = [&](const double f) {
+    return llvm::ConstantFP::get(x->getType(), f);
   };
-  auto multiply_add = [&](absl::Span<const float> coefficients,
+  // Evaluates the polynomial whose coefficients are given highest degree
+  // first at 'w', using Horner's scheme.
+  auto multiply_add = [&](absl::Span<const double> coefficients,
                           llvm::Value* w) {
-    llvm::Value* p = getFloat(coefficients.front());
+    llvm::Value* p = get_float(coefficients.front());
     coefficients.remove_prefix(1);
-    for (float coefficient : coefficients) {
-      p = FAdd(FMul(p, w), getFloat(coefficient));
+    // Iterate in double so that F64 coefficients keep their full precision.
+    for (double coefficient : coefficients) {
+      p = FAdd(FMul(p, w), get_float(coefficient));
     }
     return p;
   };
   // Approximation for inverse error function from
   // Giles, M., "Approximating the erfinv function".
-  // The approximation has the form:
-  // w = log((1-x)*(1+x))
+  // The approximation has the form (float version):
+  // w = -log((1-x)*(1+x))
   // if ( w < 5 ) {
   //   w = w - 2.5
   //   p = sum_{i=1}^n lq[i]*w^i
@@ -879,46 +890,126 @@
   // }
   // return p*x
   llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
-      module_, llvm::Intrinsic::log, {b_->getFloatTy()});
+      module_, llvm::Intrinsic::log, {x->getType()});
-  llvm::Value* w = FNeg(
-      Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))}));
+  llvm::Value* w = FNeg(Call(
+      logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))}));
   llvm::Value* p_addr =
-      llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_);
+      llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_);
-  llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
-      FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
-  // Handle true BB.
-  SetToFirstInsertPoint(if_data.true_block, b_);
-  {
-    llvm::Value* lw = FSub(w, getFloat(2.5f));
-    absl::Span<const float> lq{
-        2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
-        -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
-        -0.00417768164f, 0.246640727f, 1.50140941f};
-    llvm::Value* p = multiply_add(lq, lw);
-    Store(p, p_addr);
+  if (prim_type == F16 || prim_type == F32) {
+    llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+        FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_);
+    // Handle true BB.
+    SetToFirstInsertPoint(if_data.true_block, b_);
+    {
+      llvm::Value* lw = FSub(w, get_float(2.5f));
+      // Coefficient tables have static storage: a Span initialized from a
+      // braced list would point at a temporary array destroyed immediately.
+      static constexpr double lq[] = {
+          2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
+          -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
+          -0.00417768164f, 0.246640727f, 1.50140941f};
+      llvm::Value* p = multiply_add(lq, lw);
+      Store(p, p_addr);
+    }
+
+    // Handle false BB.
+    SetToFirstInsertPoint(if_data.false_block, b_);
+    {
+      llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+          module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
+
+      llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f));
+      static constexpr double gq[] = {
+          -0.000200214257f, 0.000100950558f, 0.00134934322f,
+          -0.00367342844f, 0.00573950773f, -0.0076224613f,
+          0.00943887047f, 1.00167406f, 2.83297682f};
+      llvm::Value* p = multiply_add(gq, gw);
+      Store(p, p_addr);
+    }
+
+    SetToFirstInsertPoint(if_data.after_block, b_);
+  } else {
+    DCHECK(prim_type == F64);
+
+    llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+        FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_);
+
+    SetToFirstInsertPoint(if_data.true_block, b_);
+    {
+      llvm::Value* lw = FSub(w, get_float(3.125));
+      static constexpr double c[] = {
+          -3.6444120640178196996e-21, -1.685059138182016589e-19,
+          1.2858480715256400167e-18, 1.115787767802518096e-17,
+          -1.333171662854620906e-16, 2.0972767875968561637e-17,
+          6.6376381343583238325e-15, -4.0545662729752068639e-14,
+          -8.1519341976054721522e-14, 2.6335093153082322977e-12,
+          -1.2975133253453532498e-11, -5.4154120542946279317e-11,
+          1.051212273321532285e-09, -4.1126339803469836976e-09,
+          -2.9070369957882005086e-08, 4.2347877827932403518e-07,
+          -1.3654692000834678645e-06, -1.3882523362786468719e-05,
+          0.0001867342080340571352, -0.00074070253416626697512,
+          -0.0060336708714301490533, 0.24015818242558961693,
+          1.6536545626831027356};
+      llvm::Value* p = multiply_add(c, lw);
+      Store(p, p_addr);
+    }
+
+    SetToFirstInsertPoint(if_data.false_block, b_);
+    llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse(
+        FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_);
+    SetToFirstInsertPoint(if_data_second.true_block, b_);
+    {
+      llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+          module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+      llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25));
+      static constexpr double t1[] = {
+          2.2137376921775787049e-09, 9.0756561938885390979e-08,
+          -2.7517406297064545428e-07, 1.8239629214389227755e-08,
+          1.5027403968909827627e-06, -4.013867526981545969e-06,
+          2.9234449089955446044e-06, 1.2475304481671778723e-05,
+          -4.7318229009055733981e-05, 6.8284851459573175448e-05,
+          2.4031110387097893999e-05, -0.0003550375203628474796,
+          0.00095328937973738049703, -0.0016882755560235047313,
+          0.0024914420961078508066, -0.0037512085075692412107,
+          0.005370914553590063617, 1.0052589676941592334,
+          3.0838856104922207635};
+      llvm::Value* p = multiply_add(t1, gw);
+      Store(p, p_addr);
+    }
+
+    SetToFirstInsertPoint(if_data_second.false_block, b_);
+    {
+      llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+          module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+      llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0));
+      static constexpr double t2[] = {
+          -2.7109920616438573243e-11, -2.5556418169965252055e-10,
+          1.5076572693500548083e-09, -3.7894654401267369937e-09,
+          7.6157012080783393804e-09, -1.4960026627149240478e-08,
+          2.9147953450901080826e-08, -6.7711997758452339498e-08,
+          2.2900482228026654717e-07, -9.9298272942317002539e-07,
+          4.5260625972231537039e-06, -1.9681778105531670567e-05,
+          7.5995277030017761139e-05, -0.00021503011930044477347,
+          -0.00013871931833623122026, 1.0103004648645343977,
+          4.8499064014085844221};
+      llvm::Value* p = multiply_add(t2, gw);
+      Store(p, p_addr);
+    }
+
+    SetToFirstInsertPoint(if_data.after_block, b_);
   }
-
-  // Handle false BB.
-  SetToFirstInsertPoint(if_data.false_block, b_);
-  {
-    llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
-        module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
-
-    llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
-    absl::Span<const float> gq{
-        -0.000200214257f, 0.000100950558f, 0.00134934322f,
-        -0.00367342844f, 0.00573950773f, -0.0076224613f,
-        0.00943887047f, 1.00167406f, 2.83297682f};
-    llvm::Value* p = multiply_add(gq, gw);
-    Store(p, p_addr);
-  }
-
-  SetToFirstInsertPoint(if_data.after_block, b_);
   llvm::Value* p = Load(p_addr);
-  return FMul(p, x);
+  x = FMul(p, x);
+  // Trunc back to half if needed.
+  if (prim_type == F16) {
+    x = b_->CreateFPTrunc(x, b_->getHalfTy());
+  }
+  return x;
 }
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type,
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h
index 3cccec9..986970f 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph.h
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.h
@@ -26,7 +26,7 @@
// Flattening associates each call site with a unique computation (for
// sequential calling contexts) This simplifies buffer assignment and
// points-to analysis (see b/36865746 for details).
-class FlattenCallGraph : public HloPassInterface {
+class FlattenCallGraph : public HloModulePass {
public:
absl::string_view name() const override { return "flatten-call-graph"; }
diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h
index 7bd9ea5..2b39359 100644
--- a/tensorflow/compiler/xla/service/gather_expander.h
+++ b/tensorflow/compiler/xla/service/gather_expander.h
@@ -23,7 +23,7 @@
// This pass rewrites gather operations into (roughly) while loops of dynamic
// slices. This lets backends that don't support gather directly to
// nevertheless have a minimum level of support.
-class GatherExpander : public HloPassInterface {
+class GatherExpander : public HloModulePass {
public:
absl::string_view name() const override { return "gather_expander"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 64b9683..cbee4db 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -68,9 +68,7 @@
# srcs = [
# "partition_assignment_test.cc",
# ],
-# tags = [
-# "requires-gpu-sm35",
-# ],
+# tags = tf_cuda_tests_tags(),
# deps = [
# ":partition_assignment",
# "//tensorflow/core:stream_executor_no_cuda",
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 3a23ac1..85f3682 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -29,21 +29,54 @@
 namespace xla {
 namespace gpu {
-using se::dnn::AlgorithmDesc;
+// Only records the buffer slices; they are resolved to device addresses each
+// time ExecuteOnStream() runs.
+ConvolutionThunk::ConvolutionThunk(
+    const HloCustomCallInstruction* cudnn_call,
+    std::vector<BufferAllocation::Slice> operand_slices,
+    BufferAllocation::Slice result_slice,
+    BufferAllocation::Slice scratch_slice,
+    BufferAllocation::Slice tuple_result_slice)
+    : Thunk{Kind::kConvolution, cudnn_call},
+      cudnn_call_{cudnn_call},
+      operand_buffers_{std::move(operand_slices)},
+      result_buffer_{result_slice},
+      scratch_buffer_{scratch_slice},
+      tuple_result_buffer_{tuple_result_slice} {}
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
CudnnConvParams params;
+ TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, ¶ms));
- params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
- params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
- params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
+ switch (params.kind) {
+ case CudnnConvKind::kForward:
+ params.input_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
+ params.filter_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
+ params.output_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
+ break;
+ case CudnnConvKind::kBackwardInput:
+ params.input_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
+ params.filter_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
+ params.output_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
+ break;
+ case CudnnConvKind::kBackwardFilter:
+ params.input_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
+ params.filter_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
+ params.output_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
+ break;
+ }
+
se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
- TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, ¶ms));
-
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index d7d1f91..f53bc54 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -42,24 +42,12 @@
// Constructs a thunk for launching a DNN convolution. When run, it will
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
- // Note that "output" here doesn't refer to the output from running this
- // thunk, but rather to the "output" of a hypothetical forward convolution
- // that corresponds to this input+filter+output triple. That is, the result
- // generated by this thunk is "output" for forward convs, "input" for
- // backward-input convs, and "filter" for backward-filter convs.
+ // operand_slices should be in the same order as cudnn_call->operands().
ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
- BufferAllocation::Slice input_slice,
- BufferAllocation::Slice filter_slice,
- BufferAllocation::Slice output_slice,
+ std::vector<BufferAllocation::Slice> operand_slices,
+ BufferAllocation::Slice result_slice,
BufferAllocation::Slice scratch_slice,
- BufferAllocation::Slice tuple_result_slice)
- : Thunk(Kind::kConvolution, cudnn_call),
- cudnn_call_(cudnn_call),
- input_buffer_(std::move(input_slice)),
- filter_buffer_(std::move(filter_slice)),
- output_buffer_(std::move(output_slice)),
- scratch_buffer_(std::move(scratch_slice)),
- tuple_result_buffer_(std::move(tuple_result_slice)) {}
+ BufferAllocation::Slice tuple_result_slice);
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -71,9 +59,8 @@
private:
const HloCustomCallInstruction* cudnn_call_;
- BufferAllocation::Slice input_buffer_;
- BufferAllocation::Slice filter_buffer_;
- BufferAllocation::Slice output_buffer_;
+ std::vector<BufferAllocation::Slice> operand_buffers_;
+ BufferAllocation::Slice result_buffer_;
BufferAllocation::Slice scratch_buffer_;
BufferAllocation::Slice tuple_result_buffer_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
index 6e2e330..c3f5850 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
@@ -52,7 +52,7 @@
// The GPU backend does not implement a lowering for the batchnorm HLOs -- it
// expects them to be lowered to cudnn calls via this pass or to HLO soup via
// BatchNormRewriter.
-class CudnnBatchNormRewriter : public HloPassInterface {
+class CudnnBatchNormRewriter : public HloModulePass {
public:
absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index f79b113..ce01895 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -30,7 +30,7 @@
// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
// each and adding explicit scratch space to the CustomCalls.
-class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
+class CudnnConvolutionAlgorithmPicker : public HloModulePass {
public:
// If the `allocator` parameter is not null, we will use it to allocate temp
// memory while timing the various convolution algorithms. If it's null,
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
index fbe7e98..8d7c6fd 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
@@ -24,7 +24,7 @@
// Rewrites plain convolutions, backwards-filter convolutions, and
// backwards-input convolutions into CustomCall HLOs that call into cuDNN.
-class CudnnConvolutionRewriter : public HloPassInterface {
+class CudnnConvolutionRewriter : public HloModulePass {
public:
absl::string_view name() const override {
return "cudnn-convolution-rewriter";
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
index 7e3f577..f19996e 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
@@ -32,7 +32,7 @@
// 2) The result of merging the fusion instruction into its users would not
// increase bytes transferred.
//
-class FusionMerger : public HloPassInterface {
+class FusionMerger : public HloModulePass {
public:
absl::string_view name() const override { return "fusion merger"; }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index 75f414e..79c74e7 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -34,15 +34,6 @@
namespace gpu {
-StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
- HloInstruction* hlo) {
- HloInstruction*& copy = hlo_to_copy_map_[hlo];
- if (copy == nullptr) {
- TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo));
- }
- return copy;
-}
-
StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
CopyInsertion generic_copy_insertion;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
index 8ffae18..4c7e38f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
@@ -25,20 +25,11 @@
// Besides the modifications made by the generic xla::CopyInsertion, this
// GPU-specific copy insertion also materializes operands of library calls by
// inserting kCopy instructions.
-class GpuCopyInsertion : public HloPassInterface {
+class GpuCopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
-
- protected:
- // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making
- // duplicate copies.
- StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
-
- // A map containing all copies inserted to materialize operands of library
- // calls. The key is the copied instruction and the value is the copy.
- tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> hlo_to_copy_map_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
index bbb3340..9c64b4d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
@@ -23,7 +23,7 @@
// his pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the GPU backend and cannot be removed via HLO
// transformations (eg, sparse layouts).
-class GpuHloSupportChecker : public HloPassInterface {
+class GpuHloSupportChecker : public HloModulePass {
public:
GpuHloSupportChecker() = default;
~GpuHloSupportChecker() override = default;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index b669881..c792dd2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -465,35 +465,18 @@
if (IsCustomCallToDnnConvolution(*custom_call)) {
const auto& assn = ir_emitter_context_->buffer_assignment();
- auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
- auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
+ std::vector<BufferAllocation::Slice> operand_slices;
+ operand_slices.reserve(custom_call->operand_count());
+ for (const auto* operand : custom_call->operands()) {
+ operand_slices.push_back(GetAllocationSlice(*operand));
+ }
auto tuple_result_slice = GetAllocationSlice(*custom_call);
auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
- const auto& target = custom_call->custom_call_target();
- BufferAllocation::Slice input_slice, filter_slice, output_slice;
-
- if (target == kCudnnConvForwardCallTarget) {
- input_slice = lhs_slice;
- filter_slice = rhs_slice;
- output_slice = conv_result_slice;
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- input_slice = conv_result_slice;
- filter_slice = rhs_slice;
- output_slice = lhs_slice;
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- input_slice = lhs_slice;
- filter_slice = conv_result_slice;
- output_slice = rhs_slice;
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << custom_call->custom_call_target();
- }
-
thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
- Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
- output_slice, scratch_slice, tuple_result_slice));
+ Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
+ conv_result_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
index 11dc56a..e592a37 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
@@ -30,7 +30,7 @@
// targeting before running this pass.
//
// TODO(jlebar): Also pad dots.
-class PadForTensorCores : public HloPassInterface {
+class PadForTensorCores : public HloModulePass {
public:
absl::string_view name() const override { return "pad for tensor cores"; }
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
index a622e89..25cdf64 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
@@ -24,7 +24,7 @@
// An HLO pass that canonicalizes convolution instructions for GPU codegen. It
// inserts Pad instructions before Convolution instructions with uncanonicalized
// padding, so that they can be lowered to cuDNN convolution.
-class PadInsertion : public HloPassInterface {
+class PadInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "pad insertion"; }
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index db4a33dc..5da6f23 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -25,15 +25,17 @@
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
cc_library(
name = "gpu_codegen_test",
testonly = True,
srcs = ["gpu_codegen_test.cc"],
hdrs = ["gpu_codegen_test.h"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:gpu_plugin",
@@ -48,9 +50,7 @@
tf_cc_test(
name = "gpu_copy_test",
srcs = ["gpu_copy_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -67,9 +67,7 @@
tf_cc_test(
name = "gpu_ftz_test",
srcs = ["gpu_ftz_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/core:test_main",
@@ -79,9 +77,7 @@
tf_cc_test(
name = "gpu_index_test",
srcs = ["gpu_index_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -102,9 +98,7 @@
tf_cc_test(
name = "gpu_infeed_test",
srcs = ["infeed_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -125,9 +119,7 @@
tf_cc_test(
name = "gpu_kernel_tiling_test",
srcs = ["gpu_kernel_tiling_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo",
@@ -142,7 +134,7 @@
tf_cc_test(
name = "gpu_ldg_test",
srcs = ["gpu_ldg_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -159,9 +151,7 @@
tf_cc_test(
name = "gpu_noalias_test",
srcs = ["gpu_noalias_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -178,9 +168,7 @@
tf_cc_test(
name = "gpu_fusion_test",
srcs = ["gpu_fusion_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo_module_config",
@@ -194,9 +182,7 @@
tf_cc_test(
name = "gpu_unrolling_test",
srcs = ["gpu_unrolling_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo_module_config",
@@ -211,9 +197,7 @@
name = "gpu_alignment_test",
testonly = True,
srcs = ["gpu_alignment_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:gpu_plugin",
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h
index 4557983..4a624cc 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.h
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h
@@ -23,7 +23,7 @@
// A pass which performs constant folding in order to avoid unnecessary
// computation on constants.
-class HloConstantFolding : public HloPassInterface {
+class HloConstantFolding : public HloModulePass {
public:
absl::string_view name() const override { return "constant_folding"; }
diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h
index a28c035..e4857fd 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.h
+++ b/tensorflow/compiler/xla/service/hlo_cse.h
@@ -25,7 +25,7 @@
// and identical instructions with the same operands are commoned. The pass
// iterates over the instructions in topological order which enables the pass to
// find arbitrarily large common expressions.
-class HloCSE : public HloPassInterface {
+class HloCSE : public HloModulePass {
public:
// If is_layout_sensitive is true, then the simplifier preserves layout during
// transformation. Otherwise, layout is ignored.
diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h
index 1fe69b1..4012042 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_dce.h
@@ -33,7 +33,7 @@
//
// This pass does not remove dead parameter instructions, as parameter
// instructions cannot be deleted.
-class HloDCE : public HloPassInterface {
+class HloDCE : public HloModulePass {
public:
~HloDCE() override {}
absl::string_view name() const override { return "dce"; }
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
index d36631f..c0bf1b9 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
@@ -30,7 +30,7 @@
// used to break an HLO graph edge connecting two instructions with different
// sharding. If a set of connected instructions have all the same sharding, no
// kDomain instruction will be placed.
-class HloDomainIsolator : public HloPassInterface {
+class HloDomainIsolator : public HloModulePass {
public:
// Creates a new kDomain instruction for the edge between the use instruction
// (the first HloInstruction argument), and the operand instruction (the
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h
index 97bc8ef..0fc30fb 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_remover.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h
@@ -26,7 +26,7 @@
// Removes all the kDomain instructions of a given kind from the input module,
// and calls the normalizer to propagate the properties on the possibly new born
// instructions.
-class HloDomainRemover : public HloPassInterface {
+class HloDomainRemover : public HloModulePass {
public:
// Creates a new HloDomainRemover object tasked at removing all the kDomain
// instructions of a given kind.
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
index 81d6d69..bea5cba 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
@@ -29,7 +29,7 @@
// Verifies that the domain instructions are consistent, and the each domain is
// surrounded by the same metadata.
-class HloDomainVerifier : public HloPassInterface {
+class HloDomainVerifier : public HloModulePass {
public:
HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {}
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
index 44ded2c..4d2a942 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
@@ -25,7 +25,7 @@
// inserting Convert ops. This allows a backend to support an element type while
// only actually implementing the Convert op for that element type. This is
// generally not the fastest approach, but it works.
-class HloElementTypeConverter : public HloPassInterface {
+class HloElementTypeConverter : public HloModulePass {
public:
// eliminate_type is the type to eliminate as the input or output of ops,
// using Convert ops to replace it with replace_with_type.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index e905f29..ad58833 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -2910,6 +2910,26 @@
return os << ToString(kind);
}
+bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
+ const HloInstruction* const& rhs) const {
+ if (rhs == nullptr) {
+ // Nothing compares less than nullptr.
+ return false;
+ }
+ if (lhs == nullptr) {
+ return true;
+ }
+ auto lhs_module = lhs->GetModule();
+ auto rhs_module = rhs->GetModule();
+ CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
+ (lhs_module != nullptr && rhs_module != nullptr));
+ if (lhs_module != nullptr &&
+ lhs_module->unique_id() != rhs_module->unique_id()) {
+ return lhs_module->unique_id() < rhs_module->unique_id();
+ }
+ return lhs->unique_id() < rhs->unique_id();
+}
+
bool HloInstruction::CouldBeBitcast() const {
switch (opcode_) {
case HloOpcode::kTranspose:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 1ef8cd5..d615df0 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1693,21 +1693,9 @@
// To make the iteration order over the map deterministic, the comparator
// should not be using the pointer values, but rather an intrinsic property of
// the hlo. Exception: null pointer values compare less than non-null.
-//
-// Note that this cannot be used for HLO instructions across multiple modules
-// since the id of HLO instructions are only unique within each HLO module.
struct HloPtrComparator {
bool operator()(const HloInstruction* const& lhs,
- const HloInstruction* const& rhs) const {
- if (rhs == nullptr) {
- // Nothing compares less than nullptr.
- return false;
- }
- if (lhs == nullptr) {
- return true;
- }
- return lhs->unique_id() < rhs->unique_id();
- }
+ const HloInstruction* const& rhs) const;
};
template <typename ValueT>
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
index 5e02868..9964c6f 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -90,7 +90,7 @@
// A pass which schedules the HLO instructions in a module. The HloModule's
// schedule field is set to the resulting HloSchedule using
// HloModule::set_schedule.
-class HloMemoryScheduler : public HloPassInterface {
+class HloMemoryScheduler : public HloModulePass {
public:
// size_function is the function returning the number of bytes required for a
// LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
@@ -109,7 +109,7 @@
// A trivial pass which clears the schedule currently set on the
// HloModule. After this pass runs HloModudle::has_schedule will return false.
-class HloDescheduler : public HloPassInterface {
+class HloDescheduler : public HloModulePass {
public:
HloDescheduler() = default;
~HloDescheduler() override = default;
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 3bc2d13..735804e 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -63,6 +63,7 @@
// tests). The versioned handle is used by the service in the compilation
// cache. A default configuration is created for this module.
explicit HloModule(const string& name, const HloModuleConfig& config);
+ virtual ~HloModule() {}
// Adds an entry computation to the module. A module can only have one entry
// computation. Returns a pointer to the newly added computation.
@@ -87,6 +88,7 @@
const std::unordered_map<HloComputation*, HloComputation*>& replacements);
const string& name() const { return name_; }
+ void set_name(string name) { name_ = std::move(name); }
// Returns a deep copy of this module including all computations.
std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
@@ -255,7 +257,7 @@
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_identifiers);
- const string name_;
+ string name_;
HloModuleConfig config_;
HloComputation* entry_computation_ = nullptr;
std::vector<std::unique_ptr<HloComputation>> computations_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h
index 12ca234..d472211 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.h
@@ -28,7 +28,7 @@
// Sweeps through live instructions which cross computation boundaries (kWhile),
// and removes code at dead shape indices.
//
-class HloModuleDCE : public HloPassInterface {
+class HloModuleDCE : public HloModulePass {
public:
~HloModuleDCE() override {}
absl::string_view name() const override { return "hlo-module-dce"; }
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 9c01862..83352ef 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -392,22 +392,28 @@
if (!ContainsKey(companion_set_index_, instruction1) &&
!ContainsKey(companion_set_index_, instruction2)) {
companion_sets_.push_back(
- absl::make_unique<std::unordered_set<HloInstruction*>>());
+ absl::make_unique<std::vector<HloInstruction*>>());
auto companion_set = companion_sets_.back().get();
- companion_set->insert(instruction1);
- companion_set->insert(instruction2);
+ companion_set->push_back(instruction1);
+ companion_set->push_back(instruction2);
companion_set_index_[instruction1] = companion_sets_.size() - 1;
companion_set_index_[instruction2] = companion_sets_.size() - 1;
} else if (!ContainsKey(companion_set_index_, instruction1)) {
- companion_sets_[companion_set_index_[instruction2]]->insert(instruction1);
+ companion_sets_[companion_set_index_[instruction2]]->push_back(
+ instruction1);
companion_set_index_[instruction1] = companion_set_index_[instruction2];
} else if (!ContainsKey(companion_set_index_, instruction2)) {
- companion_sets_[companion_set_index_[instruction1]]->insert(instruction2);
+ companion_sets_[companion_set_index_[instruction1]]->push_back(
+ instruction2);
companion_set_index_[instruction2] = companion_set_index_[instruction1];
} else if (companion_set_index_[instruction1] !=
companion_set_index_[instruction2]) {
- companion_sets_[companion_set_index_[instruction1]]->insert(
- Companions(instruction2).begin(), Companions(instruction2).end());
+ // At any point while building the companion sets, each instruction belongs
+ // to at most 1 companion set, so the union of two companion sets is
+ // concatenating two disjoint sets.
+ absl::c_copy(Companions(instruction2),
+ std::back_inserter(
+ *companion_sets_[companion_set_index_[instruction1]]));
int64 index_to_remove = companion_set_index_[instruction2];
for (HloInstruction* hlo : Companions(instruction2)) {
companion_set_index_[hlo] = companion_set_index_[instruction1];
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 768b0c7..278d94c 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -169,14 +169,14 @@
// Returns the companion instructions for the given instruction.
//
// Precondition: IsCompanionWhile(instruction) is true.
- const std::unordered_set<HloInstruction*>& Companions(
+ const std::vector<HloInstruction*>& Companions(
const HloInstruction* instruction) const {
CHECK_EQ(companion_set_index_.count(instruction), 1);
return companion_set(companion_set_index_.at(instruction));
}
// Returns the companion set at the given index.
- const std::unordered_set<HloInstruction*>& companion_set(int64 index) const {
+ const std::vector<HloInstruction*>& companion_set(int64 index) const {
CHECK_LT(index, companion_sets_.size());
return *companion_sets_[index];
}
@@ -187,7 +187,7 @@
}
// Returns the list of all companion sets in the HLO module group.
- const std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>&
+ const std::vector<std::unique_ptr<std::vector<HloInstruction*>>>&
companion_sets() const {
return companion_sets_;
}
@@ -247,8 +247,7 @@
void DumpCollectedStats() const;
// List of all companion instructions sets in the module.
- std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>
- companion_sets_;
+ std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_;
// Map from each companion while instruction to the index into companion_set_.
tensorflow::gtl::FlatMap<const HloInstruction*, int64> companion_set_index_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
index ebf790b..b7b12cb 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
@@ -17,6 +17,7 @@
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -137,6 +138,69 @@
::testing::ElementsAre(op::Parameter()));
}
+// Tests that the order of companion instructions in the companion set doesn't
+// change across runs.
+TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) {
+ // A simple while loop template for core i sending to core i+1.
+ constexpr char text[] = R"(
+HloModule module_%d
+
+while_cond {
+ ROOT p = pred[] constant(true)
+}
+
+while_body {
+ param = s32[] parameter(0)
+ token.s = token[] after-all()
+ token.r = token[] after-all()
+ send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d
+ send-done = token[] send-done(send), channel_id=%d
+ recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d
+ ROOT recv-done = (s32[], token[]) recv-done(recv), channel_id=%d
+}
+
+ENTRY entry {
+ while_init = s32[] constant(1)
+ ROOT while = s32[] while(while_init), condition=while_cond, body=while_body
+}
+)";
+
+ // Try creating the module and the metadata kTrialCount times and check the
+ // companion instructions remain in the same order.
+ const int64 kTrialCount = 5;
+ const int64 kDeviceCount = 10;
+ std::vector<int64> companion_order;
+
+ for (int64 t = 0; t < kTrialCount; ++t) {
+ HloModuleGroup group(TestName());
+ for (int64 i = 0; i < kDeviceCount; ++i) {
+ const int64 send_channel = i;
+ const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1;
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseHloString(absl::StrFormat(text, i, send_channel, send_channel,
+ recv_channel, recv_channel)));
+ group.push_back(std::move(module));
+ }
+ ASSERT_EQ(group.modules().size(), kDeviceCount);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto metadata,
+ HloModuleGroupMetadata::Build(group.modules()));
+ ASSERT_EQ(metadata->companion_sets().size(), 1);
+
+ std::vector<int64> module_ids;
+ for (HloInstruction* companion : *metadata->companion_sets()[0]) {
+ module_ids.push_back(metadata->GetModuleId(companion->GetModule()));
+ }
+
+ if (t == 0) {
+ companion_order = module_ids;
+ } else {
+ EXPECT_TRUE(absl::c_equal(companion_order, module_ids));
+ }
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 11caa89..37197b2 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -64,14 +64,11 @@
public:
using LocTy = HloLexer::LocTy;
- explicit HloParser(absl::string_view str, const HloModuleConfig& config)
- : lexer_(str), config_(config) {}
+ explicit HloParser(absl::string_view str) : lexer_(str) {}
- // Runs the parser. Returns false if an error occurred.
- bool Run();
-
- // Returns the parsed HloModule.
- std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
+ // Runs the parser and constructs the resulting HLO in the given (empty)
+ // HloModule. Returns false if an error occurred.
+ bool Run(HloModule* module);
// Returns the error information.
string GetError() const { return StrJoin(error_, "\n"); }
@@ -98,8 +95,8 @@
const string& name, const optional<Shape>& shape = nullopt);
// ParseXXX returns false if an error occurred.
- bool ParseHloModule();
- bool ParseComputations();
+ bool ParseHloModule(HloModule* module);
+ bool ParseComputations(HloModule* module);
bool ParseComputation(HloComputation** entry_computation);
bool ParseInstructionList(HloComputation::Builder* builder,
string* root_name);
@@ -293,9 +290,7 @@
computation_pool_;
HloLexer lexer_;
- std::unique_ptr<HloModule> module_;
std::vector<std::unique_ptr<HloComputation>> computations_;
- const HloModuleConfig config_;
std::vector<string> error_;
// Function that gets invoked when we try to resolve an instruction
@@ -349,9 +344,9 @@
return Error(lexer_.GetLoc(), msg);
}
-bool HloParser::Run() {
+bool HloParser::Run(HloModule* module) {
lexer_.Lex();
- return ParseHloModule();
+ return ParseHloModule(module);
}
std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
@@ -366,7 +361,7 @@
}
// ::= 'HloModule' name computations
-bool HloParser::ParseHloModule() {
+bool HloParser::ParseHloModule(HloModule* module) {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
return TokenError("expects HloModule");
}
@@ -385,22 +380,20 @@
return false;
}
- module_ = absl::make_unique<HloModule>(name, config_);
-
- if (!ParseComputations()) {
+ module->set_name(name);
+ if (!ParseComputations(module)) {
return false;
}
if (is_scheduled.has_value() && *is_scheduled) {
- TF_CHECK_OK(
- module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+ TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
}
return true;
}
// computations ::= (computation)+
-bool HloParser::ParseComputations() {
+bool HloParser::ParseComputations(HloModule* module) {
HloComputation* entry_computation = nullptr;
do {
if (!ParseComputation(&entry_computation)) {
@@ -416,21 +409,20 @@
if ((entry_computation != nullptr &&
computations_[i].get() != entry_computation) ||
(entry_computation == nullptr && i != computations_.size() - 1)) {
- module_->AddEmbeddedComputation(std::move(computations_[i]));
+ module->AddEmbeddedComputation(std::move(computations_[i]));
continue;
}
- auto computation =
- module_->AddEntryComputation(std::move(computations_[i]));
+ auto computation = module->AddEntryComputation(std::move(computations_[i]));
// The parameters and result layouts were set to default layout. Here we
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
@@ -3247,53 +3239,62 @@
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config) {
- HloParser parser(str, config);
- if (!parser.Run()) {
+ auto module = absl::make_unique<HloModule>(/*name=*/"", config);
+ HloParser parser(str);
+ if (!parser.Run(module.get())) {
return InvalidArgument("Syntax error:\n%s", parser.GetError());
}
- return parser.ConsumeHloModule();
+ return std::move(module);
}
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
- HloModuleConfig config;
- return ParseHloString(str, config);
+ auto module = absl::make_unique<HloModule>(/*name=*/"", HloModuleConfig());
+ HloParser parser(str);
+ if (!parser.Run(module.get())) {
+ return InvalidArgument("Syntax error:\n%s", parser.GetError());
+ }
+ return std::move(module);
+}
+
+Status ParseHloString(absl::string_view str, HloModule* module) {
+ TF_RET_CHECK(module->computation_count() == 0);
+ HloParser parser(str);
+ if (!parser.Run(module)) {
+ return InvalidArgument("Syntax error:\n%s", parser.GetError());
+ }
+ return Status::OK();
}
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
auto builder = absl::make_unique<HloComputation::Builder>(string(name));
string root_name;
TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
std::unique_ptr<HloComputation> computation = builder->Build();
- auto module = absl::make_unique<HloModule>(string(name), config);
+ auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig());
module->AddEntryComputation(std::move(computation));
return std::move(module);
}
StatusOr<HloSharding> ParseSharding(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseShardingOnly();
}
StatusOr<Window> ParseWindow(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseWindowOnly();
}
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseConvolutionDimensionNumbersOnly();
}
StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParsePaddingConfigOnly();
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 1882a18..3696035 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -30,18 +30,23 @@
// For details about the syntax accepted by this parser, see
// g3doc/hlo_parser.md.
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with the given config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with the given config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config);
+// Given a string in the HloModule::ToString() format, parses the string and
+// builds the HloModule in place at the given module pointer. 'module' must
+// point to an empty module (no computations).
+Status ParseHloString(absl::string_view str, HloModule* module);
+
// Parses the text for a single HLO operation into an HLO module with a function
// that runs that operation (with the same parameters) as its entry computation.
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name = "single_op");
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with default config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with default config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str);
// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h
index f1ad0f9..fdaac34 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_interface.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h
@@ -17,6 +17,7 @@
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -25,15 +26,45 @@
namespace xla {
// Base class for HLO passes. These are used with the HloPassPipeline to
-// organize a sequence of passes.
+// organize a sequence of passes. An HLO pass should not extend this class
+// directly; it should extend HloModulePass or HloModuleGroupPass.
class HloPassInterface {
public:
virtual ~HloPassInterface() = default;
virtual absl::string_view name() const = 0;
- // Run the pass on the given HLO module. Return whether it modified the
+ // Run the pass on the given HLO module. Returns whether it modified the
// module.
virtual StatusOr<bool> Run(HloModule* module) = 0;
+
+ // Run the pass on the given HLO module group. Returns whether it modified the
+ // module group. Ideally, the module group variant would be named "Run" as
+ // well, but C++ does not handle overloaded virtual methods well.
+ virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0;
+};
+
+// Base class for passes which are module-scoped. Subclasses implement Run().
+class HloModulePass : public HloPassInterface {
+ public:
+  // Runs the pass on each module of the group in order; returns whether any
+  // module changed. The first error from Run() is returned immediately.
+  StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+    bool changed = false;
+    for (HloModule* module : module_group->modules()) {
+      TF_ASSIGN_OR_RETURN(bool module_changed, Run(module));
+      changed |= module_changed;
+    }
+    return changed;
+  }
+};
+
+// Base class for passes which are module-group scoped. These passes cannot run
+// on an HLO module.
+class HloModuleGroupPass : public HloPassInterface {
+ public:
+ StatusOr<bool> Run(HloModule* module) override {
+ return InternalError("Module group pass cannot be run on a module");
+ }
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 6e4ed0d..8c2f928 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -17,7 +17,6 @@
#include <functional>
-#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@@ -29,108 +28,128 @@
#include "tensorflow/core/platform/logging.h"
namespace xla {
-namespace {
-using absl::StrAppend;
-using absl::StrCat;
+template <typename HloT>
+Status HloPassPipeline::RunInvariantCheckers(
+ HloT* hlo, absl::string_view after_pass_name) {
+ for (auto& invariant_checker : invariant_checkers_) {
+ VLOG(1) << " Invariant checker " << invariant_checker->name();
+ StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
+ VLOG(1) << " Invariant checker done " << invariant_checker->name();
+ if (!changed_status.ok()) {
+ VLOG(2) << "Failed invariant check:";
+ XLA_VLOG_LINES(2, hlo->ToString());
+ return Status(changed_status.status().code(),
+ absl::StrCat(changed_status.status().error_message(),
+ "\n\nFailed after ", after_pass_name));
+ }
+ TF_RET_CHECK(!changed_status.ValueOrDie())
+ << "invariant checkers must not change the graph";
+ }
+ return Status::OK();
+}
-void DumpModuleGraph(const HloModule& module, const string& message) {
+template <typename HloT>
+StatusOr<bool> HloPassPipeline::RunPassesInternal(
+ HloT* hlo, absl::Span<HloPassInterface* const> passes) {
+ string last_pass_name = "pipeline-start";
+ TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
+ bool changed = false;
+ for (HloPassInterface* pass : passes) {
+ VLOG(1) << " HLO pass " << pass->name();
+ MaybeDumpHlo(*hlo,
+ /*after_pass_name=*/last_pass_name,
+ /*before_pass_name=*/pass->name());
+ TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
+ changed |= pass_changed;
+ TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name()));
+ last_pass_name = string(pass->name());
+ }
+ MaybeDumpHlo(*hlo,
+ /*after_pass_name=*/last_pass_name,
+ /*before_pass_name=*/"pipeline-end");
+ return changed;
+}
+
+std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
+    const DebugOptions& debug_options) {
+  const auto& repeated_field = debug_options.xla_disable_hlo_passes();
+  tensorflow::gtl::FlatSet<string> disabled_pass_names(repeated_field.begin(),
+                                                       repeated_field.end());
+  if (!disabled_pass_names.empty()) {
+    VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
+            << absl::StrJoin(disabled_pass_names, ", ");
+  }
+
+  std::vector<HloPassInterface*> enabled_passes;
+  for (auto& pass : passes_) {
+    if (disabled_pass_names.count(string(pass->name())) == 0) {
+      enabled_passes.push_back(pass.get());
+    }
+  }
+  return enabled_passes;
+}
+
+void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
+                                   absl::string_view after_pass_name,
+                                   absl::string_view before_pass_name) {
+  const string& proto_dump_path =
+      module.config().debug_options().xla_dump_per_pass_hlo_proto_to();
+  if (!proto_dump_path.empty()) {
+    static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
+    static auto* const module_id_to_pass_number =
+        new tensorflow::gtl::FlatMap<int64, int64>();
+
+    tensorflow::mutex_lock lock(mu);
+    const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
+
+    const string filename = SanitizeFileName(
+        absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
+                        pass_number, name(), after_pass_name));
+
+    TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(
+        MakeHloProto(module), proto_dump_path, filename));
+  }
+
+  const string message =
+      absl::StrCat("after ", after_pass_name, ", before ", before_pass_name);
   hlo_graph_dumper::MaybeDumpHloModule(module, message);
   VLOG(3) << "HLO " << message << ":";
   XLA_VLOG_LINES(3, module.ToString());
 }
-void DumpModuleProto(const HloModule& module, const string& dump_to,
- const string& pipeline_name, const string& pass_name) {
- static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
- static auto* const module_id_to_pass_number =
- new tensorflow::gtl::FlatMap<int64, int64>();
-
- tensorflow::mutex_lock lock(mu);
- const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
-
- const string mod_name = SanitizeFileName(
- absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
- pass_number, pipeline_name, pass_name));
-
- TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module),
- dump_to, mod_name));
+void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
+ absl::string_view after_pass_name,
+ absl::string_view before_pass_name) {
+ for (const HloModule* module : module_group.modules()) {
+ MaybeDumpHlo(*module, after_pass_name, before_pass_name);
+ }
}
-} // namespace
StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
run_called_ = true;
- VLOG(1) << "Running HLO pass pipeline " << name();
+ VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
+ << name();
- auto repeated_field =
- module->config().debug_options().xla_disable_hlo_passes();
- tensorflow::gtl::FlatSet<string> disabled_passes(repeated_field.begin(),
- repeated_field.end());
- if (!disabled_passes.empty()) {
- VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
- << absl::StrJoin(disabled_passes, ", ");
+ return RunPassesInternal(module,
+ GetEnabledPasses(module->config().debug_options()));
+}
+
+StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
+ run_called_ = true;
+
+ VLOG(1) << "Running HLO pass pipeline on module group "
+ << module_group->name() << ": " << name();
+
+ if (module_group->modules().empty()) {
+ VLOG(1) << "Module group is empty. Nothing to do.";
+ return false;
}
- auto run_invariant_checkers = [this,
- module](const string& message) -> Status {
- for (auto& invariant_checker : invariant_checkers_) {
- VLOG(1) << " Invariant checker " << invariant_checker->name();
- StatusOr<bool> changed_status = invariant_checker->Run(module);
- VLOG(1) << " Invariant checker done " << invariant_checker->name();
- if (!changed_status.ok()) {
- VLOG(2) << "Module failed invariant check:";
- XLA_VLOG_LINES(2, module->ToString());
- return Status(changed_status.status().code(),
- StrCat(changed_status.status().error_message(),
- "\n\nFailed ", message));
- }
- TF_RET_CHECK(!changed_status.ValueOrDie())
- << "invariant checkers must not change the graph";
- }
- return Status::OK();
- };
-
- string prefix = StrCat(name(), ": pipeline start");
- bool changed = false;
- string message;
- TF_RETURN_IF_ERROR(
- run_invariant_checkers(StrCat("before running pipeline: ", name())));
- const string xla_dump_per_pass_hlo_proto_to =
- module->config().debug_options().xla_dump_per_pass_hlo_proto_to();
- if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
- "pipeline_start");
- }
-
- for (auto& pass : passes_) {
- if (disabled_passes.count(string(pass->name())) > 0) {
- VLOG(1) << " Skipping HLO pass " << pass->name()
- << ", disabled by --xla_disable_hlo_passes";
- continue;
- }
-
- VLOG(1) << " HLO pass " << pass->name();
-
- // Emit label containing: "after foo-pass, before bar-pass".
- message.clear();
- StrAppend(&message, prefix, ", before ", pass->name());
- DumpModuleGraph(*module, message);
-
- TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
- TF_RETURN_IF_ERROR(
- run_invariant_checkers(StrCat("after running pass: ", pass->name())));
- if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
- string(pass->name()));
- }
-
- changed |= changed_this_pass;
- prefix.clear();
- StrAppend(&prefix, name(), ": after ", pass->name());
- }
- DumpModuleGraph(*module, prefix + ", pipeline end");
- return changed;
+ return RunPassesInternal(
+ module_group,
+ GetEnabledPasses(module_group->module(0).config().debug_options()));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index 1d41a4d..09e7033 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -22,6 +22,7 @@
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -61,10 +62,45 @@
return *pass;
}
- // Run all passes on the given HLO module.
StatusOr<bool> Run(HloModule* module) override;
+ StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override;
private:
+ // Returns the set of passes which are enabled. DebugOptions can selectively
+ // disable passes via --xla_disable_hlo_passes flag.
+ std::vector<HloPassInterface*> GetEnabledPasses(
+ const DebugOptions& debug_options);
+
+ // Maybe dumps the given module or module group depending on flag values
+ // contained in DebugOptions of module config.
+ void MaybeDumpHlo(const HloModuleGroup& module_group,
+ absl::string_view after_pass_name,
+ absl::string_view before_pass_name);
+ void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name,
+ absl::string_view before_pass_name);
+
+  // Runs each invariant checker on the given HLO. Fails if a checker errors or
+  // changes the HLO. HloT can be either HloModule or HloModuleGroup.
+  template <typename HloT>
+  Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name);
+
+  // Runs the given passes in order on the given HLO, running the invariant
+  // checkers first and after each pass. HloT is HloModule or HloModuleGroup.
+  template <typename HloT>
+  StatusOr<bool> RunPassesInternal(HloT* hlo,
+                                   absl::Span<HloPassInterface* const> passes);
+
+  // Helpers which run a single pass on the given HLO construct. These
+  // helpers enable templating of the core of the pipeline logic by providing
+  // HloModule and HloModuleGroup specific methods with the same name.
+  static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) {
+    return pass->Run(module);
+  }
+  static StatusOr<bool> RunHelper(HloPassInterface* pass,
+                                  HloModuleGroup* module_group) {
+    return pass->RunOnModuleGroup(module_group);
+  }
+
const string name_;
std::vector<std::unique_ptr<HloPassInterface>> passes_;
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
new file mode 100644
index 0000000..ee8cb12
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
@@ -0,0 +1,259 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloPassPipelineTest : public HloVerifiedTestBase {
+ protected:
+ StatusOr<HloModuleGroup> ParseModuleGroup(
+ absl::Span<const string> hlo_strings) {
+ HloModuleGroup group(TestName());
+ for (const string& hlo_string : hlo_strings) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ group.push_back(std::move(module));
+ }
+ return std::move(group);
+ }
+};
+
+// A module pass which renames instructions named 'foo' to 'bar'.
+class FooToBarModulePass : public HloModulePass {
+ absl::string_view name() const override { return "foo2bar"; }
+
+ StatusOr<bool> Run(HloModule* module) override {
+ bool changed = false;
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->name() == "foo") {
+ instruction->SetAndSanitizeName("bar");
+ changed = true;
+ }
+ }
+ }
+ return changed;
+ }
+};
+
+// A module group pass which renames instructions named 'baz' to 'qux'.
+class BazToQuxModuleGroupPass : public HloModuleGroupPass {
+ absl::string_view name() const override { return "baz2qux"; }
+
+ StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+ bool changed = false;
+ for (HloModule* module : module_group->modules()) {
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->name() == "baz") {
+ instruction->SetAndSanitizeName("qux");
+ changed = true;
+ }
+ }
+ }
+ }
+ return changed;
+ }
+};
+
+// An invariant checker pass which returns an error if there exists an
+// instruction named 'bar'.
+class BarBlowerUpper : public HloModulePass {
+ absl::string_view name() const override { return "bar-blower-upper"; }
+
+ StatusOr<bool> Run(HloModule* module) override {
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->name() == "bar") {
+ return InternalError("Module has instruction named bar");
+ }
+ }
+ }
+ return false;
+ }
+};
+
+TEST_F(HloPassPipelineTest, ModulePassChanged) {
+ // Test an HLO module pass which changes a module.
+ const string module_str = R"(
+HloModule ModulePassChanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<FooToBarModulePass>();
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->name(), "foo");
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_TRUE(changed);
+ EXPECT_EQ(root->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, ModulePassUnchanged) {
+ // Test an HLO module pass which does not change a module.
+ const string module_str = R"(
+HloModule ModulePassUnchanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT blahblah = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<FooToBarModulePass>();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(HloPassPipelineTest, MixedPipeline) {
+ // Test a pipeline with both a module pass and a module group pass.
+ const string module_0_str = R"(
+HloModule MixedPipeline.1
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT baz = f32[] multiply(a, b)
+}
+)";
+ const string module_1_str = R"(
+HloModule MixedPipeline.0
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
+ ParseModuleGroup({module_0_str, module_1_str}));
+
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<BazToQuxModuleGroupPass>();
+ pipeline.AddPass<FooToBarModulePass>();
+
+ HloInstruction* root0 =
+ module_group.module(0).entry_computation()->root_instruction();
+ HloInstruction* root1 =
+ module_group.module(1).entry_computation()->root_instruction();
+ EXPECT_EQ(root0->name(), "baz");
+ EXPECT_EQ(root1->name(), "foo");
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ pipeline.RunOnModuleGroup(&module_group));
+ EXPECT_TRUE(changed);
+
+ EXPECT_EQ(root0->name(), "qux");
+ EXPECT_EQ(root1->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, InvariantChecker) {
+ const string module_str = R"(
+HloModule InvariantChecker
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ {
+ // Run a pipeline with just the invariant checker. It should not fail
+ // because there is no 'bar' instruction in the module.
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_FALSE(changed);
+ }
+
+ {
+ // Run a pipeline which renames 'foo' to 'bar' then an invariant checker
+ // which fails if there is an instruction named 'bar'.
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddInvariantChecker<BarBlowerUpper>();
+ pipeline.AddPass<FooToBarModulePass>();
+
+ Status status = pipeline.Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Module has instruction named bar"));
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Failed after foo2bar"));
+ }
+
+ {
+ // Run the invariant-checker only pipeline again. It should fail this time.
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+ Status status = pipeline.Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Module has instruction named bar"));
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Failed after pipeline-start"));
+ }
+}
+
+TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) {
+ // Running a module group pass on a module should produce an error.
+ const string module_str = R"(
+HloModule ModuleGroupPassOnModule
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<BazToQuxModuleGroupPass>();
+
+ Status status = pipeline.Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Module group pass cannot be run on a module"));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index bd6dd79..a438671 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1198,6 +1198,12 @@
<< HumanReadableNumBytes(memory_limit_bytes_);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
+ // Initialize pass object state.
+ computation_peak_memory_.clear();
+ rematerialized_computations_.clear();
+ instructions_rematerialized_ = 0;
+ net_instructions_added_ = 0;
+
TF_RET_CHECK(module->has_schedule());
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index e2aaf18..7330d73 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -33,7 +33,7 @@
// CSE will undo the effects of this optimization and should not be run after
// this pass. In general, this pass should be run very late, immediately before
// code generation.
-class HloRematerialization : public HloPassInterface {
+class HloRematerialization : public HloModulePass {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
index d1cf644..fa34bdd 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
@@ -22,7 +22,7 @@
// Unify subcomputations of a `HloModule`: if any computations are equal, choose
// one arbitrarily to use and delete the others.
-class HloSubcomputationUnification : public HloPassInterface {
+class HloSubcomputationUnification : public HloModulePass {
public:
absl::string_view name() const override {
return "subcomputation-unification";
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 773fc7d..8549487 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -131,6 +131,7 @@
CHECK_LE(operand_number, 2);
return operand_number == 0 || index.empty();
+ case HloOpcode::kDomain:
case HloOpcode::kTuple:
// These instructions always pass through their operands transparently.
return false;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 50f39cb..6eb6658 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1057,6 +1057,7 @@
} // namespace
StatusOr<bool> HloVerifier::Run(HloModule* module) {
+ TF_RET_CHECK(!module->name().empty());
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 42e3027..0cde4a3 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -151,7 +151,7 @@
// HLO pass that verifies invariants of HLO instructions for each computation in
// the module.
-class HloVerifier : public HloPassInterface {
+class HloVerifier : public HloModulePass {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
index 85bb4a8..9c48b7d 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
@@ -25,7 +25,7 @@
// Pass which replaces all implicit broadcasts with their equivalent sequence of
// explicit broadcast and reshape instructions.
-class ImplicitBroadcastRemover : public HloPassInterface {
+class ImplicitBroadcastRemover : public HloModulePass {
public:
ImplicitBroadcastRemover() {}
~ImplicitBroadcastRemover() override {}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index df9cbab..3e238f9 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -366,7 +366,7 @@
// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to
// unconditionally add to the regular HLO pass pipeline.
-class IndexedArrayAnalysisPrinterPass : public HloPassInterface {
+class IndexedArrayAnalysisPrinterPass : public HloModulePass {
public:
absl::string_view name() const override;
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h
index efa8ed3..e20af08 100644
--- a/tensorflow/compiler/xla/service/inliner.h
+++ b/tensorflow/compiler/xla/service/inliner.h
@@ -24,7 +24,7 @@
// A pass which performs inlining. Which can result, for example, in functions
// that were previously being mapped by Map instead directly applied to the
// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
-class Inliner : public HloPassInterface {
+class Inliner : public HloModulePass {
public:
~Inliner() override = default;
absl::string_view name() const override { return "inline"; }
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index c1fde8e..7e1196f 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -56,7 +56,7 @@
// with the intent that the loops which compute their values will be fused in
// code generation. Derived classes define ShouldFuse method to select which
// instructions to fuse.
-class InstructionFusion : public HloPassInterface {
+class InstructionFusion : public HloModulePass {
public:
explicit InstructionFusion(
std::function<bool(const HloInstruction& instruction)> is_expensive,
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index cf54503..e29c199 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -281,7 +281,7 @@
// HLO pass which assigns layouts to all instructions in the HLO module while
// satisfying all necessary invariants and minimizing cost.
-class LayoutAssignment : public HloPassInterface {
+class LayoutAssignment : public HloModulePass {
public:
// entry_computation_layout is modified to populate a layout for the result in
// the case that no particular layout is requested.
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index d2c5265..0344626 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -44,7 +44,7 @@
// Note that the reachability map is updated based on the original computation.
// This works because the reachability is monotonically increasing with
// instruction fusion.
-class MultiOutputFusion : public HloPassInterface {
+class MultiOutputFusion : public HloModulePass {
public:
MultiOutputFusion(int64 fuel) : fuel_(fuel) {}
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index bd8fb17..ac2f796 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -39,8 +39,10 @@
}
/*static*/ string NameUniquer::GetSanitizedName(const string& name) {
+ if (name.empty()) {
+ return "";
+ }
string result = name;
- CHECK(!result.empty()) << "name should not be empty";
char c = static_cast<unsigned char>(result[0]);
if (!isalpha(c) && c != '_') {
result[0] = '_';
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
index 256b231..4bb2242 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -29,7 +29,7 @@
// HLO pass which inserts reduce-precision instructions into the HLO graph, for
// purposes of experimenting with the effects of reduced-precision storage of
// intermediate values.
-class ReducePrecisionInsertion : public HloPassInterface {
+class ReducePrecisionInsertion : public HloModulePass {
using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;
public:
diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h
index 1e86a08..a3db439 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.h
+++ b/tensorflow/compiler/xla/service/reshape_mover.h
@@ -24,7 +24,7 @@
// This now only moves them outputward across elementwise ops all whose operands
// are equivalent Reshapes or Transposes, but in future could potentially move
// them inputward also.
-class ReshapeMover : public HloPassInterface {
+class ReshapeMover : public HloModulePass {
public:
absl::string_view name() const override { return "reshape-mover"; }
diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h
index 14f062c..559a85d 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.h
+++ b/tensorflow/compiler/xla/service/scatter_expander.h
@@ -20,7 +20,7 @@
namespace xla {
-class ScatterExpander : public HloPassInterface {
+class ScatterExpander : public HloModulePass {
public:
absl::string_view name() const override { return "scatter_expander"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 74bdf2a..7194b2c 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1665,10 +1665,11 @@
if (input_features != kernel_input_features * feature_group_count) {
return InvalidArgument(
"Expected LHS feature dimension (value %d) to match RHS "
- "input feature dimension * feature_group_count (value %d); "
+ "input feature dimension * feature_group_count (value %d * %d = %d); "
"got <conv>(%s, %s)\n"
"Dimension numbers: {%s}.",
- input_features, kernel_input_features * feature_group_count,
+ input_features, kernel_input_features, feature_group_count,
+ kernel_input_features * feature_group_count,
ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
dnums.DebugString());
}
diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc
index 5d1cd1c..ec09dff 100644
--- a/tensorflow/compiler/xla/service/stream_pool.cc
+++ b/tensorflow/compiler/xla/service/stream_pool.cc
@@ -28,8 +28,14 @@
// Re-use an existing stream from the pool.
stream = std::move(streams_.back());
streams_.pop_back();
- VLOG(1) << stream->DebugStreamPointers()
- << " StreamPool reusing existing stream";
+ if (stream->ok()) {
+ VLOG(1) << stream->DebugStreamPointers()
+ << " StreamPool reusing existing stream";
+ } else {
+ VLOG(1) << stream->DebugStreamPointers()
+ << " stream was not ok, StreamPool deleting";
+ stream = nullptr;
+ }
}
}
diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc
index aaf5c37..92f4757 100644
--- a/tensorflow/compiler/xla/service/stream_pool_test.cc
+++ b/tensorflow/compiler/xla/service/stream_pool_test.cc
@@ -132,5 +132,39 @@
EXPECT_EQ(stream2_ptr, stream3_ptr);
}
+TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) {
+ std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
+ StreamPool pool;
+
+ // Borrow a stream.
+ StreamPool::Ptr stream1 = pool.BorrowStream(executor.get());
+ EXPECT_TRUE(stream1->ok());
+
+ // Return the stream, but hold a handle to it.
+ se::Stream* stream1_ptr = stream1.get();
+ stream1 = nullptr;
+
+ // Now stream1 is back in the pool, force an error on the stream. Here we call
+ // a method that requires DNN support, which we know the Host platform doesn't
+ // support.
+ stream1_ptr->ThenDepthConcatenate({}, {}, nullptr);
+ EXPECT_FALSE(stream1_ptr->ok());
+
+ // Borrow stream2.
+ StreamPool::Ptr stream2 = pool.BorrowStream(executor.get());
+ EXPECT_TRUE(stream2->ok());
+
+ // The underlying streams should be different. They would have been
+ // the same, but since we forced an error on stream1, it cannot be
+ // put back into the pool. Sadly we can't just check:
+ // EXPECT_NE(stream1_ptr, stream2_ptr);
+ //
+ // The above should hold logically, but it may fail if the new
+ // stream instance allocated for stream2 happens to reside in the
+ // same memory address as stream1, which has been deleted.
+ //
+ // The check that stream2->ok() serves as a good-enough check.
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h
index 3e5aa2d..f95f982 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.h
+++ b/tensorflow/compiler/xla/service/transpose_folding.h
@@ -23,7 +23,7 @@
// HLO pass that folds transpose operators into Dot operators, where the Dot
// operator is implemented by a GEMM kernel that can transpose its inputs.
-class TransposeFolding : public HloPassInterface {
+class TransposeFolding : public HloModulePass {
public:
using OperandIndices = std::vector<int64>;
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h
index 8c91d6e..e126a53 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.h
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.h
@@ -25,7 +25,7 @@
// A pass which simplifies patterns of Tuple and GetTupleElement instructions in
// the module.
-class TupleSimplifier : public HloPassInterface {
+class TupleSimplifier : public HloModulePass {
public:
TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
explicit TupleSimplifier(bool exclude_entry_computation);
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
index 2dba7d7..577bad6 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
@@ -50,7 +50,7 @@
// conditions as well.
//
// TODO(b/79121449): We should also sink broadcasts of constants.
-class WhileLoopConstantSinking : public HloPassInterface {
+class WhileLoopConstantSinking : public HloModulePass {
public:
~WhileLoopConstantSinking() override = default;
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 2cdf20c..3031899 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -25,7 +25,7 @@
// HLO pass that rewrites while loops to hoist loop invariant instructions in
// the while body into the computation that contains the while instruction.
-class WhileLoopInvariantCodeMotion : public HloPassInterface {
+class WhileLoopInvariantCodeMotion : public HloModulePass {
public:
// If `hoist_constants` is true then constants are always hoisted out of while
// loop bodies. Otherwise they are only hoisted out if they enable other
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h
index 78024f1..0bc5a01 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.h
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h
@@ -30,7 +30,7 @@
// - Elements of a while loop's tuple that the loop doesn't use are removed
// from the tuple.
//
-class WhileLoopSimplifier : public HloPassInterface {
+class WhileLoopSimplifier : public HloModulePass {
public:
~WhileLoopSimplifier() override {}
absl::string_view name() const override { return "simplify-while-loops"; }
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
index a7f0e20..8729412 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
@@ -21,7 +21,7 @@
// HLO pass that replaces zero sized Hlos with a zero sized constant literal.
namespace xla {
-class ZeroSizedHloElimination : public HloPassInterface {
+class ZeroSizedHloElimination : public HloModulePass {
public:
StatusOr<bool> Run(HloModule* module) override;
absl::string_view name() const override {
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 623ae39..d8bb27b 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -22,6 +22,7 @@
#include <initializer_list>
#include <string>
+#include "absl/base/macros.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
@@ -479,8 +480,7 @@
// Shorthand for testing whether a shape is of a given element type and
// sequence of dimensions.
- //
- // DEPRECATED: Use Equal() instead.
+ ABSL_DEPRECATED("Use Equal() instead.")
static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
std::initializer_list<int64> dimensions);
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 30e3077..fd3e3bf 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -29,6 +29,10 @@
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
# Generate test_suites for all backends, named "${backend}_tests".
generate_backend_suites()
@@ -150,11 +154,31 @@
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/core:lib",
- "//tensorflow/core:test",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
],
)
+tf_cc_test(
+ name = "hlo_verified_test_base_test",
+ srcs = ["hlo_verified_test_base_test.cc"],
+ deps = [
+ ":hlo_test_base",
+ ":hlo_verified_test_base",
+ ":test_macros_cpu",
+ ":test_utils",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
tf_cc_binary(
name = "local_client_aot_test_helper",
srcs = ["local_client_aot_test_helper.cc"],
@@ -1797,7 +1821,7 @@
tf_cc_test(
name = "llvm_compiler_test",
srcs = ["llvm_compiler_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test_helpers",
@@ -2096,7 +2120,7 @@
name = "sample_file_test",
srcs = ["sample_file_test.cc"],
data = ["isolated_convolution.hlo"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":hlo_test_base",
"//tensorflow/compiler/xla:test",
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl
index 53f2c3b..cc65a89 100644
--- a/tensorflow/compiler/xla/tests/build_defs.bzl
+++ b/tensorflow/compiler/xla/tests/build_defs.bzl
@@ -3,256 +3,266 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
all_backends = ["cpu", "gpu"] + plugins.keys()
def filter_backends(backends):
- """Removes "gpu" from a backend list if CUDA is not enabled.
+ """Removes "gpu" from a backend list if CUDA is not enabled.
- This allows us to simply hardcode lists including "gpu" here and in the
- BUILD file, without causing failures when CUDA isn't enabled.'
+ This allows us to simply hardcode lists including "gpu" here and in the
+ BUILD file, without causing failures when CUDA isn't enabled.
- Args:
- backends: A list of backends to filter.
+ Args:
+ backends: A list of backends to filter.
- Returns:
- The filtered list of backends.
- """
- if cuda_is_configured():
- return backends
- else:
- return [backend for backend in backends if backend != "gpu"]
-
-
-def xla_test(name,
- srcs,
- deps,
- xla_test_library_deps=[],
- backends=[],
- blacklisted_backends=[],
- args=[],
- tags=[],
- copts=[],
- data=[],
- backend_tags={},
- backend_args={},
- **kwargs):
- """Generates cc_test targets for the given XLA backends.
-
- This rule generates a cc_test target for one or more XLA backends and also a
- platform-agnostic cc_library rule. The arguments are identical to cc_test with
- two additions: 'backends' and 'backend_args'. 'backends' specifies the
- backends to generate tests for ("cpu", "gpu"), and
- 'backend_args'/'backend_tags' specifies backend-specific args parameters to
- use when generating the cc_test.
-
- The name of the cc_tests are the provided name argument with the backend name
- appended, and the cc_library target name is the provided name argument with
- "_lib" appended. For example, if name parameter is "foo_test", then the cpu
- test target will be "foo_test_cpu" and the cc_library target is "foo_lib".
-
- The cc_library target can be used to link with other plugins outside of
- xla_test.
-
- The build rule also defines a test suite ${name} which includes the tests for
- each of the supported backends.
-
- Each generated cc_test target has a tag indicating which backend the test is
- for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These
- tags can be used to gather tests for a particular backend into a test_suite.
-
- Examples:
-
- # Generates the targets: foo_test_cpu and foo_test_gpu.
- xla_test(
- name = "foo_test",
- srcs = ["foo_test.cc"],
- backends = ["cpu", "gpu"],
- deps = [...],
- )
-
- # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
- # includes the additional arg "--special_cpu_flag".
- xla_test(
- name = "bar_test",
- srcs = ["bar_test.cc"],
- backends = ["cpu", "gpu"],
- backend_args = {"cpu": ["--special_cpu_flag"]}
- deps = [...],
- )
-
- The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
- to the value 1 where ${BACKEND} is the uppercase name of the backend.
-
- Args:
- name: Name of the target.
- srcs: Sources for the target.
- deps: Dependencies of the target.
- xla_test_library_deps: If set, the generated test targets will depend on the
- respective cc_libraries generated by the xla_test_library rule.
- backends: A list of backends to generate tests for. Supported values: "cpu",
- "gpu". If this list is empty, the test will be generated for all supported
- backends.
- blacklisted_backends: A list of backends to NOT generate tests for.
- args: Test arguments for the target.
- tags: Tags for the target.
- copts: Additional copts to pass to the build.
- data: Additional data to pass to the build.
- backend_tags: A dict mapping backend name to list of additional tags to
- use for that target.
- backend_args: A dict mapping backend name to list of additional args to
- use for that target.
- **kwargs: Additional keyword arguments to pass to native.cc_test.
- """
- test_names = []
- if not backends:
- backends = all_backends
-
- backends = [backend for backend in backends
- if backend not in blacklisted_backends]
-
- native.cc_library(
- name="%s_lib" % name,
- srcs=srcs,
- copts=copts,
- testonly=True,
- deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
- )
-
- for backend in filter_backends(backends):
- test_name = "%s_%s" % (name, backend)
- this_backend_tags = ["xla_%s" % backend]
- this_backend_copts = []
- this_backend_args = backend_args.get(backend, [])
- this_backend_data = []
- if backend == "cpu":
- backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
- backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
- elif backend == "gpu":
- backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
- backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
- this_backend_tags += ["requires-gpu-sm35"]
- elif backend in plugins:
- backend_deps = []
- backend_deps += plugins[backend]["deps"]
- this_backend_copts += plugins[backend]["copts"]
- this_backend_tags += plugins[backend]["tags"]
- this_backend_args += plugins[backend]["args"]
- this_backend_data += plugins[backend]["data"]
+ Returns:
+ The filtered list of backends.
+ """
+ if cuda_is_configured():
+ return backends
else:
- fail("Unknown backend %s" % backend)
+ return [backend for backend in backends if backend != "gpu"]
- if xla_test_library_deps:
- for lib_dep in xla_test_library_deps:
- backend_deps += ["%s_%s" % (lib_dep, backend)]
+def xla_test(
+ name,
+ srcs,
+ deps,
+ xla_test_library_deps = [],
+ backends = [],
+ blacklisted_backends = [],
+ args = [],
+ tags = [],
+ copts = [],
+ data = [],
+ backend_tags = {},
+ backend_args = {},
+ **kwargs):
+ """Generates cc_test targets for the given XLA backends.
- tf_cc_test(
- name=test_name,
- srcs=srcs,
- tags=tags + backend_tags.get(backend, []) + this_backend_tags,
- extra_copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
- this_backend_copts,
- args=args + this_backend_args,
- deps=deps + backend_deps,
- data=data + this_backend_data,
- **kwargs)
+ This rule generates a cc_test target for one or more XLA backends and also a
+ platform-agnostic cc_library rule. The arguments are identical to cc_test with
+ two additions: 'backends' and 'backend_args'. 'backends' specifies the
+ backends to generate tests for ("cpu", "gpu"), and
+ 'backend_args'/'backend_tags' specifies backend-specific args parameters to
+ use when generating the cc_test.
- test_names.append(test_name)
+ The names of the cc_tests are the provided name argument with the backend name
+ appended, and the cc_library target name is the provided name argument with
+ "_lib" appended. For example, if name parameter is "foo_test", then the cpu
+ test target will be "foo_test_cpu" and the cc_library target is "foo_test_lib".
- native.test_suite(name=name, tests=test_names)
+ The cc_library target can be used to link with other plugins outside of
+ xla_test.
-def xla_test_library(name,
- srcs,
- hdrs=[],
- deps=[],
- backends=[]):
- """Generates cc_library targets for the given XLA backends.
+ The build rule also defines a test suite ${name} which includes the tests for
+ each of the supported backends.
- This rule forces the sources to be compiled for each backend so that the
- backend specific macros could expand correctly. It's useful when test targets
- in different directories referring to the same sources but test with different
- arguments.
+ Each generated cc_test target has a tag indicating which backend the test is
+ for. This tag is of the form "xla_${BACKEND}" (e.g., "xla_cpu"). These
+ tags can be used to gather tests for a particular backend into a test_suite.
- Examples:
+ Examples:
- # Generates the targets: foo_test_library_cpu and foo_test_gpu.
- xla_test_library(
- name = "foo_test_library",
- srcs = ["foo_test.cc"],
- backends = ["cpu", "gpu"],
- deps = [...],
- )
- # Then use the xla_test rule to generate test targets:
- xla_test(
- name = "foo_test",
- srcs = [],
- backends = ["cpu", "gpu"],
- deps = [...],
- xla_test_library_deps = [":foo_test_library"],
- )
+ # Generates the targets: foo_test_cpu and foo_test_gpu.
+ xla_test(
+ name = "foo_test",
+ srcs = ["foo_test.cc"],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ )
- Args:
- name: Name of the target.
- srcs: Sources for the target.
- hdrs: Headers for the target.
- deps: Dependencies of the target.
- backends: A list of backends to generate libraries for.
- Supported values: "cpu", "gpu". If this list is empty, the
- library will be generated for all supported backends.
- """
+ # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
+ # includes the additional arg "--special_cpu_flag".
+ xla_test(
+ name = "bar_test",
+ srcs = ["bar_test.cc"],
+ backends = ["cpu", "gpu"],
+ backend_args = {"cpu": ["--special_cpu_flag"]},
+ deps = [...],
+ )
- if not backends:
- backends = all_backends
+ The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
+ to the value 1 where ${BACKEND} is the uppercase name of the backend.
- for backend in filter_backends(backends):
- this_backend_copts = []
- if backend in ["cpu", "gpu"]:
- backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
- elif backend in plugins:
- backend_deps = plugins[backend]["deps"]
- this_backend_copts += plugins[backend]["copts"]
- else:
- fail("Unknown backend %s" % backend)
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ deps: Dependencies of the target.
+ xla_test_library_deps: If set, the generated test targets will depend on the
+ respective cc_libraries generated by the xla_test_library rule.
+ backends: A list of backends to generate tests for. Supported values: "cpu",
+ "gpu". If this list is empty, the test will be generated for all supported
+ backends.
+ blacklisted_backends: A list of backends to NOT generate tests for.
+ args: Test arguments for the target.
+ tags: Tags for the target.
+ copts: Additional copts to pass to the build.
+ data: Additional data to pass to the build.
+ backend_tags: A dict mapping backend name to list of additional tags to
+ use for that target.
+ backend_args: A dict mapping backend name to list of additional args to
+ use for that target.
+ **kwargs: Additional keyword arguments to pass to native.cc_test.
+ """
+ test_names = []
+ if not backends:
+ backends = all_backends
+
+ backends = [
+ backend
+ for backend in backends
+ if backend not in blacklisted_backends
+ ]
native.cc_library(
- name = "%s_%s" % (name, backend),
+ name = "%s_lib" % name,
srcs = srcs,
+ copts = copts,
testonly = True,
- hdrs = hdrs,
- copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()]
- + this_backend_copts,
- deps = deps + backend_deps,
+ deps = deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
)
+ for backend in filter_backends(backends):
+ test_name = "%s_%s" % (name, backend)
+ this_backend_tags = ["xla_%s" % backend]
+ this_backend_copts = []
+ this_backend_args = backend_args.get(backend, [])
+ this_backend_data = []
+ if backend == "cpu":
+ backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
+ elif backend == "gpu":
+ backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
+ this_backend_tags += tf_cuda_tests_tags()
+ elif backend in plugins:
+ backend_deps = []
+ backend_deps += plugins[backend]["deps"]
+ this_backend_copts += plugins[backend]["copts"]
+ this_backend_tags += plugins[backend]["tags"]
+ this_backend_args += plugins[backend]["args"]
+ this_backend_data += plugins[backend]["data"]
+ else:
+ fail("Unknown backend %s" % backend)
-def generate_backend_suites(backends=[]):
- if not backends:
- backends = all_backends
- for backend in filter_backends(backends):
- native.test_suite(name="%s_tests" % backend,
- tags = ["xla_%s" % backend])
+ if xla_test_library_deps:
+ for lib_dep in xla_test_library_deps:
+ backend_deps += ["%s_%s" % (lib_dep, backend)]
+ tf_cc_test(
+ name = test_name,
+ srcs = srcs,
+ tags = tags + backend_tags.get(backend, []) + this_backend_tags,
+ extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
+ this_backend_copts,
+ args = args + this_backend_args,
+ deps = deps + backend_deps,
+ data = data + this_backend_data,
+ **kwargs
+ )
-def generate_backend_test_macros(backends=[]):
- if not backends:
- backends = all_backends
- for backend in filter_backends(backends):
- manifest = ""
- if backend in plugins:
- manifest = plugins[backend]["disabled_manifest"]
+ test_names.append(test_name)
- native.cc_library(
- name="test_macros_%s" % backend,
- testonly = True,
- srcs = ["test_macros.cc"],
- hdrs = ["test_macros.h"],
- copts = [
- "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
- "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
- ],
- deps = [
- "//tensorflow/compiler/xla:types",
- "//tensorflow/core:lib",
- "//tensorflow/core:regexp_internal",
- "//tensorflow/core:test",
- ])
+ native.test_suite(name = name, tests = test_names)
+
+def xla_test_library(
+ name,
+ srcs,
+ hdrs = [],
+ deps = [],
+ backends = []):
+ """Generates cc_library targets for the given XLA backends.
+
+ This rule forces the sources to be compiled for each backend so that the
+ backend specific macros could expand correctly. It's useful when test targets
+ in different directories referring to the same sources but test with different
+ arguments.
+
+ Examples:
+
+ # Generates the targets: foo_test_library_cpu and foo_test_library_gpu.
+ xla_test_library(
+ name = "foo_test_library",
+ srcs = ["foo_test.cc"],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ )
+ # Then use the xla_test rule to generate test targets:
+ xla_test(
+ name = "foo_test",
+ srcs = [],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ xla_test_library_deps = [":foo_test_library"],
+ )
+
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ hdrs: Headers for the target.
+ deps: Dependencies of the target.
+ backends: A list of backends to generate libraries for.
+ Supported values: "cpu", "gpu". If this list is empty, the
+ library will be generated for all supported backends.
+ """
+
+ if not backends:
+ backends = all_backends
+
+ for backend in filter_backends(backends):
+ this_backend_copts = []
+ if backend in ["cpu", "gpu"]:
+ backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
+ elif backend in plugins:
+ backend_deps = plugins[backend]["deps"]
+ this_backend_copts += plugins[backend]["copts"]
+ else:
+ fail("Unknown backend %s" % backend)
+
+ native.cc_library(
+ name = "%s_%s" % (name, backend),
+ srcs = srcs,
+ testonly = True,
+ hdrs = hdrs,
+ copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
+ this_backend_copts,
+ deps = deps + backend_deps,
+ )
+
+def generate_backend_suites(backends = []):
+ if not backends:
+ backends = all_backends
+ for backend in filter_backends(backends):
+ native.test_suite(
+ name = "%s_tests" % backend,
+ tags = ["xla_%s" % backend],
+ )
+
+def generate_backend_test_macros(backends = []):
+ if not backends:
+ backends = all_backends
+ for backend in filter_backends(backends):
+ manifest = ""
+ if backend in plugins:
+ manifest = plugins[backend]["disabled_manifest"]
+
+ native.cc_library(
+ name = "test_macros_%s" % backend,
+ testonly = True,
+ srcs = ["test_macros.cc"],
+ hdrs = ["test_macros.h"],
+ copts = [
+ "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
+ "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:regexp_internal",
+ "//tensorflow/core:test",
+ ],
+ )
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index 8f86c52..8bd0a72 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -21,64 +21,68 @@
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/test.h"
namespace xla {
+Status VerifiedHloModule::Verify() {
+ if (computation_count() == 0) {
+ // The computation was never built. Nothing to verify.
+ return Status::OK();
+ }
+ return verifier_.Run(this).status();
+}
+
+void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
+ Status status = Verify();
+ if (!status.ok()) {
+ ADD_FAILURE() << "HloVerifier failed on module " << name()
+ << (message.empty() ? "" : absl::StrCat(" (", message, ")"))
+ << ": " << status;
+ }
+}
+
HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
bool allow_mixed_precision)
: HloTestBase(
/*verifier_layout_sensitive=*/layout_sensitive,
- /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {}
-
-HloVerifiedTestBase::~HloVerifiedTestBase() {
- // We can't call the ASSERT or EXPECT test macros in destructors, so we
- // perform HLO verification in TearDown, and use the CHECK here to ensure
- // users don't accidentally override the verification.
- CHECK(tear_down_called_)
- << "TearDown was never called; subclasses of HloVerifiedTestBase that "
- << "override TearDown must call the superclass TearDown.";
-}
-
-void HloVerifiedTestBase::TearDown() {
- EXPECT_FALSE(tear_down_called_)
- << "TearDown called more than once; it should be called exactly once.";
- tear_down_called_ = true;
- if (module_) {
- VerifyModule(module_.get());
- }
- for (int i = 0; i < modules_.size(); ++i) {
- VerifyModule(modules_.at(i).get());
- }
- HloTestBase::TearDown();
-}
-
-void HloVerifiedTestBase::VerifyModule(HloModule* module) {
- xla::StatusOr<bool> mutated = verifier().Run(module);
- if (!mutated.ok()) {
- ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
- } else {
- EXPECT_FALSE(mutated.ValueOrDie())
- << "HloVerifier should never mutate the HloModule";
- }
-}
+ /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision),
+ verifier_layout_sensitive_(layout_sensitive),
+ allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {}
HloModule& HloVerifiedTestBase::module() {
if (!module_) {
- module_ = HloTestBase::CreateNewModule();
+ module_ = CreateNewVerifiedModule(TestName());
}
return *module_;
}
HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
- modules_.emplace_back(HloTestBase::CreateNewModule());
+ modules_.emplace_back(CreateNewVerifiedModule(name));
return modules_.back().get();
}
void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config) {
CHECK(!module_) << "Called ParseModule when test already has a module.";
- TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config));
- VerifyModule(module_.get());
+ module_ = CreateNewVerifiedModule(TestName());
+ TF_CHECK_OK(ParseHloString(hlo_text, module_.get()));
+ module_->VerifyOrAddFailure("after parsing");
}
+
+StatusOr<std::unique_ptr<VerifiedHloModule>>
+HloVerifiedTestBase::ParseAndReturnVerifiedModule(
+ absl::string_view hlo_text, const HloModuleConfig& config) {
+ auto module = CreateNewVerifiedModule(TestName());
+ TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
+ TF_RETURN_IF_ERROR(module->Verify());
+ return std::move(module);
+}
+
+std::unique_ptr<VerifiedHloModule> HloVerifiedTestBase::CreateNewVerifiedModule(
+ const string& name) {
+ return absl::make_unique<VerifiedHloModule>(
+ name, GetModuleConfigForTest(), verifier_layout_sensitive_,
+ allow_mixed_precision_in_hlo_verifier_);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index 8fbc4fa..388a99b 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -20,53 +20,84 @@
#include <memory>
#include <utility>
+#include "absl/base/macros.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
-// A base class for HLO tests that stores a default HloModule, and automatically
-// performs verification on that module on tear-down.
+// An HLO module derived class which verifies itself on destruction. This class
+// is intended to be used in unit tests. Any verification errors are raised via
+// ADD_FAILURE.
+class VerifiedHloModule : public HloModule {
+ public:
+ VerifiedHloModule(const string& name, const HloModuleConfig& config,
+ bool verifier_layout_sensitive,
+ bool allow_mixed_precision_in_hlo_verifier)
+ : HloModule(name, config),
+ verifier_(verifier_layout_sensitive,
+ allow_mixed_precision_in_hlo_verifier) {}
+
+ ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
+
+ // Verifies the module using HloVerifier and returns the status.
+ Status Verify();
+
+ // Verifies the module and flags any error with ADD_FAILURE. 'message' is
+ // included in the failure message.
+ void VerifyOrAddFailure(const string& message);
+
+ private:
+ HloVerifier verifier_;
+};
+
+// A base class for HLO tests that stores a default VerifiedHloModule.
class HloVerifiedTestBase : public HloTestBase {
protected:
- explicit HloVerifiedTestBase(bool layout_sensitive = false,
- bool allow_mixed_precision = false);
- ~HloVerifiedTestBase() override;
+ HloVerifiedTestBase(bool layout_sensitive = false,
+ bool allow_mixed_precision = false);
// Constructs a default shape verifier.
std::unique_ptr<ShapeVerifier> MakeShapeVerifier();
- // Performs verification on the default HloModule returned by module().
- // Automatically called by the testing framework for each test.
- //
- // REQUIRED: subclasses that override TearDown() must call this explicitly.
- void TearDown() override;
-
// Returns the default HloModule, lazily creating it if necessary via
- // HloTestBase::CreateNewModule().
+ // CreateNewVerifiedModule().
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
HloModule& module();
+
+ ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.")
void ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
+ // Parses the given string and returns module as a VerifiedHloModule.
+ StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
+ absl::string_view hlo_text,
+ const HloModuleConfig& config = HloModuleConfig());
+
// Creates a new module for a test, and stores it in modules_ so it can be
// verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
// creation of unverified modules.
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
HloModule* CreateNewModule(const string& name = TestName());
- private:
- void VerifyModule(HloModule* module);
+ // Creates and returns a verified HLO module with the given name.
+ std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
+ const string& name = TestName());
+ private:
// It is confusing to store modules created by module() and CreateNewModule()
// in different fields, but it allows us to migrate tests to
// HloVerifiedTestBase more easily, so it's a win because we can verify more
// modules. See b/80488902.
//
// Lazily populated. Access via module().
- std::unique_ptr<HloModule> module_;
- // Populated by calls to CreateNewModule.
- std::vector<std::unique_ptr<HloModule>> modules_;
+ std::unique_ptr<VerifiedHloModule> module_;
- bool tear_down_called_ = false;
+ // Populated by calls to CreateNewModule.
+ std::vector<std::unique_ptr<VerifiedHloModule>> modules_;
+
+ bool verifier_layout_sensitive_;
+ bool allow_mixed_precision_in_hlo_verifier_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
new file mode 100644
index 0000000..5c0263e
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// This class includes unit tests which are expected to fail because invalid HLO
+// modules are intentionally built. Unfortunately, TensorFlow doesn't appear to
+// include the necessary gunit parts to test this test machinery (it needs the
+// macro EXPECT_NONFATAL_FAILURE). The DISABLED_ tests can be run by explicitly
+// enabling disabled tests, and their failures can then be compared manually
+// against expectations.
+class HloVerifiedTestBaseTest : public HloVerifiedTestBase {};
+
+XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) {
+ // Test shouldn't fail if no module is created at all.
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) {
+ // Use module() to lazily create an empty module, build it up, and verify no
+ // failures.
+ HloModule& hlo_module = module();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ hlo_module.AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) {
+ // Use module() to lazily create an empty module and build up an invalid
+ // module.
+ HloModule& hlo_module = module();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ hlo_module.AddEntryComputation(builder.Build());
+
+ *hlo_module.entry_computation()->root_instruction()->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) {
+ // Call CreateNewModule and build up a valid module.
+ HloModule* module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ module->AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) {
+ // Call CreateNewModule and build up an invalid module.
+ HloModule* module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ module->AddEntryComputation(builder.Build());
+
+ *module->entry_computation()->root_instruction()->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) {
+ const char* const hlo_string = R"(
+HloModule ParseAndVerifyModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+)";
+
+ ParseAndVerifyModule(hlo_string);
+ EXPECT_EQ(module().entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_EQ(module->entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+
+RANDOM GARBAGE
+)";
+
+ ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleBad
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[1234] add(x,y)
+}
+)";
+
+ ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index a40c2d7..2cc33ab 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -412,6 +412,7 @@
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{0, 1}}}, //
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{1, 0}}}, //
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{0, 1}}}, //
+ R2Spec{8672, 512, {{8, 0}}, {{8672, 512}}, {{542, 1}}, {{1, 0}}}, //
R2Spec{
511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{1, 0}}}, //
R2Spec{
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 7abd865..8b1b9e1 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -763,9 +763,7 @@
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
-// Test while nodes that share the while body computation.
-// TODO(b/37245345): Fails on GPU backend.
-TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
+TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) {
std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
ShapeUtil::MakeShape(F32, {10})};
Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD
index 09ab4ed..b6dcfc4 100644
--- a/tensorflow/compiler/xrt/tests/BUILD
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -8,6 +8,10 @@
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
cc_library(
name = "raw_api_test_lib",
@@ -57,7 +61,7 @@
size = "medium",
srcs = [],
args = ["--xla_test_device=XLA_GPU"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":raw_api_test_lib",
"//tensorflow/compiler/jit:xla_gpu_device",
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
index b3f5d92..9a8f62b 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
@@ -149,7 +149,7 @@
num_devices = num_workers * num_gpus
dev_list = ["/replica:0/task:0/device:CPU:0"
for _ in range(num_devices)]
- with self.test_session():
+ with self.cached_session():
input_tensors = self._buildInitialVars(shape, dev_list)
un_op = lambda x: math_ops.div(
x, constant_op.constant(num_devices, dtype=types_pb2.DT_FLOAT))
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index 7846814..01ee870 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -43,7 +43,7 @@
def testBasicBatch(self):
"""Tests that a single batched tensor executes together and only once."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
@@ -83,7 +83,7 @@
def testBatchWithPadding(self):
"""Test that batching with padding up to an allowed batch size works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@@ -113,7 +113,7 @@
def testMultipleBatch(self):
"""Tests that multiple batched tensors execute together."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, _, _ = batch_ops.batch(
@@ -152,7 +152,7 @@
def testIllegalBatchDifferentDim0Sizes(self):
"""Tests illegally feeding tensors with different dim0 sizes."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
@@ -166,7 +166,7 @@
def testBasicUnbatch(self):
"""Tests that batch and unbatch work together."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@@ -190,7 +190,8 @@
def testBasicUnbatchV1Decorated(self):
"""Tests that the batch_function_v1 decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
+
@batch_ops.batch_function_v1(1, 10, 100000)
def computation(in_t):
return in_t + 1
@@ -211,7 +212,7 @@
def testBasicUnbatchDecorated(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# TODO(apassos): Removing this line causes test flakiness! Ideally should
# be investigated.
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
@@ -236,7 +237,7 @@
def testBatchDecoratedWithCapturedInput(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
@@ -260,7 +261,7 @@
def testBatchFunctionOp(self):
"""Tests that the batch_function op works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.int32)
def computation(in_t):
@@ -289,7 +290,7 @@
def testBatchFunctionOpWithCapturedInput(self):
"""Tests that batch_function op works with captured input."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@@ -323,7 +324,7 @@
def testBatchFunctionOpWithInputError(self):
"""Tests that batch_function op works with error in the inputs."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32, dtypes.int32)
@@ -346,7 +347,7 @@
def testBasicUnbatchDecoratedWithReshape(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
@@ -368,7 +369,7 @@
def testUnbatchTimeout(self):
"""Tests that the unbatch timeout works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
@@ -410,7 +411,7 @@
def testUnbatchGrad(self):
"""Tests that batch and unbatch are differentiable."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
index 9afe3df..18d40fc 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -27,6 +27,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.util import deprecation
__all__ = [
'expectation',
@@ -66,7 +67,7 @@
shape broadcastable to `q.batch_shape`.
For example, `log_p` works "just like" `sampling_dist_q.log_prob`.
sampling_dist_q: The sampling distribution.
- `tf.contrib.distributions.Distribution`.
+ `tfp.distributions.Distribution`.
`float64` `dtype` recommended.
`log_p` and `q` should be supported on the same set.
z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
@@ -141,7 +142,7 @@
shape broadcastable to `q.batch_shape`.
For example, `log_p` works "just like" `q.log_prob`.
sampling_dist_q: The sampling distribution.
- `tf.contrib.distributions.Distribution`.
+ `tfp.distributions.Distribution`.
`float64` `dtype` recommended.
`log_p` and `q` should be supported on the same set.
z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
@@ -188,6 +189,12 @@
return log_mean_of_values
+@deprecation.deprecated(
+ '2018-10-01',
+ 'The tf.contrib.bayesflow library has moved to '
+ 'TensorFlow Probability (https://github.com/tensorflow/probability). '
+ 'Use `tfp.monte_carlo.expectation` instead.',
+ warn_once=True)
def expectation(f, samples, log_prob=None, use_reparametrization=True,
axis=0, keep_dims=False, name=None):
r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\).
@@ -236,17 +243,17 @@
Example Use:
```python
- bf = tf.contrib.bayesflow
- ds = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Monte-Carlo approximation of a reparameterized distribution, e.g., Normal.
num_draws = int(1e5)
- p = ds.Normal(loc=0., scale=1.)
- q = ds.Normal(loc=1., scale=2.)
- exact_kl_normal_normal = ds.kl_divergence(p, q)
+ p = tfd.Normal(loc=0., scale=1.)
+ q = tfd.Normal(loc=1., scale=2.)
+ exact_kl_normal_normal = tfd.kl_divergence(p, q)
# ==> 0.44314718
- approx_kl_normal_normal = bf.expectation(
+ approx_kl_normal_normal = tfp.monte_carlo.expectation(
f=lambda x: p.log_prob(x) - q.log_prob(x),
samples=p.sample(num_draws, seed=42),
log_prob=p.log_prob,
@@ -260,9 +267,9 @@
num_draws = int(1e5)
- p = ds.Gamma(concentration=1., rate=1.)
- q = ds.Gamma(concentration=2., rate=3.)
- exact_kl_gamma_gamma = ds.kl_divergence(p, q)
+ p = tfd.Gamma(concentration=1., rate=1.)
+ q = tfd.Gamma(concentration=2., rate=3.)
+ exact_kl_gamma_gamma = tfd.kl_divergence(p, q)
# ==> 0.37999129
- approx_kl_gamma_gamma = bf.expectation(
+ approx_kl_gamma_gamma = tfp.monte_carlo.expectation(
f=lambda x: p.log_prob(x) - q.log_prob(x),
samples=p.sample(num_draws, seed=42),
log_prob=p.log_prob,
@@ -278,7 +285,7 @@
KL-divergence, the following is preferred:
```python
- approx_kl_p_q = bf.monte_carlo_csiszar_f_divergence(
- f=bf.kl_reverse,
+ approx_kl_p_q = tfp.vi.monte_carlo_csiszar_f_divergence(
+ f=tfp.vi.kl_reverse,
p_log_prob=q.log_prob,
q=p,
diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
index e36f7f3..316da9e 100644
--- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
+++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
@@ -61,7 +61,7 @@
n = itr.get_next()
expected = list(self.COMMON_ROW_KEYS)
expected.reverse()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i in range(3):
@@ -84,7 +84,7 @@
expected_keys.reverse()
expected_values = list(self.COMMON_VALUES)
expected_values.reverse()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i in range(3):
@@ -125,7 +125,7 @@
expected_keys = list(self.COMMON_ROW_KEYS)
expected_values = list(self.COMMON_VALUES)
expected_tuples = zip(expected_keys, expected_values)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i, elem in enumerate(expected_tuples):
@@ -144,7 +144,7 @@
itr = ds.make_initializable_iterator()
n = itr.get_next()
expected_key = self.COMMON_ROW_KEYS[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
output = sess.run(n)
@@ -163,7 +163,7 @@
def runSampleKeyPairsTest(self, ds, expected_key_pairs):
itr = ds.make_initializable_iterator()
n = itr.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i, elems in enumerate(expected_key_pairs):
@@ -219,7 +219,7 @@
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="r", start="r1", end="")
itr = ds.make_initializable_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(itr.initializer)
@@ -227,7 +227,7 @@
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="r", start="", end="r3")
itr = ds.make_initializable_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(itr.initializer)
@@ -235,7 +235,7 @@
ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1")
itr = ds.make_initializable_iterator()
n = itr.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
@@ -253,7 +253,7 @@
ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1")
itr = ds.make_initializable_iterator()
n = itr.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 51e0c2e..af7006b 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -579,13 +579,6 @@
const int end_index =
partition_boundaries[non_empty_partitions[root_idx]][j + 1]
.start_index;
- CHECK(bucket_ids_and_dimensions(start_index, 1) ==
- bucket_ids_and_dimensions(end_index - 1, 1))
- << "For bucket " << bucket_ids_and_dimensions(start_index, 0)
- << " the dimension was "
- << bucket_ids_and_dimensions(start_index, 1) << " and for "
- << bucket_ids_and_dimensions(end_index - 1, 0) << " "
- << bucket_ids_and_dimensions(end_index - 1, 1);
if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) {
// 0-dimension case which has a first bucket for catch all feature.
CHECK(bucket_ids_and_dimensions(start_index, 1) == 0)
diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py
index ccb8509..cc22504 100644
--- a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py
+++ b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py
@@ -45,7 +45,7 @@
eps = 0.2
- with self.test_session():
+ with self.cached_session():
predictions_tensor = constant_op.constant(
prediction_logits, dtype=dtypes.float32)
loss_for_positives, _ = losses.per_example_exp_loss(
@@ -84,7 +84,7 @@
predictions = np.array(
[[0.123], [23.2], [233], [52], [3]], dtype=np.float32)
- with self.test_session():
+ with self.cached_session():
loss_tensor, _ = losses.per_example_squared_loss(labels, weights,
predictions)
diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md
index 789dab8..77242b3 100644
--- a/tensorflow/contrib/cmake/README.md
+++ b/tensorflow/contrib/cmake/README.md
@@ -17,7 +17,7 @@
Current Status
--------------
-CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/install_windows)
+CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/source_windows)
for instructions on how to install a pre-built TensorFlow package on Windows.
### Current known limitations
diff --git a/tensorflow/contrib/coder/python/ops/coder_ops_test.py b/tensorflow/contrib/coder/python/ops/coder_ops_test.py
index d5e14e7..f5431ca 100644
--- a/tensorflow/contrib/coder/python/ops/coder_ops_test.py
+++ b/tensorflow/contrib/coder/python/ops/coder_ops_test.py
@@ -45,7 +45,7 @@
decoded = coder_ops.range_decode(
encoded, array_ops.shape(data), cdf, precision=14)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(*sess.run((data, decoded)))
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index d7583be..3b0e8f6 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -53,11 +53,14 @@
srcs = ["xla.py"],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/compiler/jit:xla_ops_py",
+ "//tensorflow/contrib/tpu:tpu_lib",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python/estimator:model_fn",
],
)
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index 42b3b9f..3e631b5 100644
--- a/tensorflow/contrib/compiler/jit_test.py
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -173,7 +173,7 @@
class CompilationEnabledInGradientTest(test.TestCase):
def testCompilationInGradient(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([[3.]])
y_nc = math_ops.matmul(x, x, name="not_compiled")
with jit.experimental_jit_scope():
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index 60f5af1..0aae695 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -12,18 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
-"""xla provides experimental xla support API."""
+"""xla is an experimental library that provides XLA support APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.compiler.jit.ops import xla_ops
+from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
@@ -51,6 +55,30 @@
])
+def compile(computation, inputs=None): # pylint: disable=redefined-builtin
+ """Builds an operator that compiles and runs `computation` with XLA.
+
+ Args:
+ computation: A Python function that builds a computation to apply to the
+ input. If the function takes n inputs, 'inputs' should be a list of n
+ tensors.
+
+ `computation` may return a list of operations and tensors. Tensors must
+ come before operations in the returned list. The return value of
+ `compile` is a list of tensors corresponding to the tensors from the
+ output of `computation`.
+
+ All `Operation`s returned from `computation` will be executed when
+ evaluating any of the returned output tensors.
+ inputs: A list of input tensors or `None` (equivalent to an empty list).
+
+ Returns:
+ A list of output tensors. If `computation` returns no tensors (only
+ operations, or nothing at all), a single NoOp `Operation` that carries
+ control dependencies on the returned operations is returned instead.
+
+ Raises:
+ TypeError: If `inputs` is not a sequence, or if `computation` cannot be
+ called with `len(inputs)` positional arguments.
+ ValueError: If `computation` returns a value that is neither an Operation
+ nor convertible to a Tensor, or returns Operations before Tensors.
+ """
+ # pylint: disable=protected-access
+ return _compile_internal(computation, inputs)
+
+
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
"""A `ControlFlowContext` for nodes inside an XLA computation cluster.
@@ -206,3 +234,122 @@
if self.GetWhileContext():
return self.GetWhileContext().back_prop
return False
+
+
+def _compile_internal(computation, inputs=None):
+ """Builds graph operators that compile and symbolically execute computation.
+
+ Args:
+ computation: A Python function that builds the computation to compile and
+ execute.
+ inputs: A list of input tensors or `None` (equivalent to `[]`). Its order
+ should match ordering of computation arguments.
+ Returns:
+ A list of output tensors from computation. If computation returns no
+ tensors, a single NoOp `Operation` that depends on the returned
+ operations is returned instead.
+ Raises:
+ ValueError: If any element in computation outputs is neither an Operation
+ nor a value that can be converted to a Tensor, or if Operations precede
+ Tensors in the outputs.
+ TypeError: If `inputs` is not a list or tuple, or if computation cannot be
+ called with `len(inputs)` arguments.
+ """
+ if inputs is None:
+ inputs = []
+
+ if not isinstance(inputs, collections.Sequence):
+ raise TypeError('inputs must be a list')
+
+ # Converts inputs to Tensors.
+ inputs = [ops.convert_to_tensor(x) for x in inputs]
+ input_arity = len(inputs)
+
+ arg_error = tpu_function.check_function_argument_count(
+ computation, input_arity, infeed_queue=None)
+ if arg_error is not None:
+ # Report the name of every input tensor (the whole list, not a single
+ # element, which would also raise IndexError when there are no inputs).
+ raise TypeError(
+ 'Supplied computation cannot be called with the specified inputs. You '
+ 'specified %d inputs: %s, but the computation needs %s' %
+ (input_arity, str([i.name for i in inputs]), arg_error))
+
+ cluster_name = ops.get_default_graph().unique_name('cluster')
+ pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
+ context = XLACompileContext(name=cluster_name, pivot=pivot)
+ try:
+ context.Enter()
+
+ # Add identity ops so even unused inputs are 'consumed' by the
+ # computation.
+ computation_inputs = [
+ array_ops.identity(x, name='input_{}'.format(i))
+ for i, x in enumerate(inputs)
+ ]
+
+ # Only resource variables work inside an XLA computation, so turn on
+ # resource variables for the computation.
+ vscope = variable_scope.get_variable_scope()
+ saved_use_resource = vscope.use_resource
+ vscope.set_use_resource(True)
+ try:
+ outputs = computation(*computation_inputs)
+ finally:
+ # Restore the variable scope even if `computation` raises, so forced
+ # resource-variable mode does not leak into the caller's graph building.
+ vscope.set_use_resource(saved_use_resource)
+
+ # If the computation returns `None`, make it an empty tuple.
+ if outputs is None:
+ outputs = tuple()
+ # If the computation only returned one value, make it a tuple.
+ if not isinstance(outputs, collections.Sequence):
+ outputs = (outputs,)
+
+ # Append `no_op` here so that return value of this function always contains
+ # at least one op that can trigger XlaLaunch node.
+ outputs += (control_flow_ops.no_op(),)
+ try:
+ outputs = [
+ o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
+ for o in outputs
+ ]
+ except Exception as e:
+ raise ValueError(
+ 'XLA computation function return values must all either be Operations'
+ ' or convertible to Tensors. Got error: "%s"' % str(e))
+
+ # Separates the returned Operations and Tensors.
+ output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
+ output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
+
+ if outputs != output_tensors + output_operations:
+ raise ValueError(
+ 'XLA computation function must return zero or more Tensor values '
+ 'followed by zero or more Operations.')
+ output_arity = len(output_tensors)
+
+ new_output_tensors = []
+ for t in output_tensors:
+ with ops.device(t.device if t.device else ''):
+ new_output_tensors.append(array_ops.identity(t))
+
+ output_tensors = new_output_tensors
+ context.ExitResult(output_tensors)
+ finally:
+ context.report_unsupported_operations()
+ context.Exit()
+
+ outputs = [
+ xla_ops.xla_cluster_output(output_tensors[i], name='output{}'.format(i))
+ for i in xrange(output_arity)
+ ]
+
+ with ops.control_dependencies(output_operations):
+ if output_arity == 0:
+ # When XLA computation returns only operations and no tensors, a NoOp
+ # dependent on the operations in outputs is returned. Otherwise final
+ # outputs would be empty and there is no way to trigger returned
+ # operations.
+ return control_flow_ops.no_op(name='output_0')
+ else:
+ # Wraps the outputs in identity operators that carry control
+ # dependencies.
+ return [
+ array_ops.identity(outputs[i], name='output_%d' % i)
+ for i in xrange(output_arity)
+ ]
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
index 5a66748..c59d368 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
@@ -413,6 +413,31 @@
self._testOneLSTMParamsSize(num_layers, num_units, input_size,
direction)
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testLSTMParamsSizeShape(self):
+ """Non-scalar size arguments must be rejected with a rank error.
+
+ Each case passes a rank-1 constant in one of the three positional size
+ arguments of `_CreateModel` (presumably num_layers, num_units and
+ input_size -- confirm against `_CreateModel`).
+ """
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ constant_op.constant([4]), 200, 200,
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ 4, constant_op.constant([200]), 200,
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ 4, 200, constant_op.constant([200]),
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ model.params_size()
+
class CudnnRNNTestInference(TensorFlowTestCase):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
index e2c9bc8..5b493f4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -173,16 +173,6 @@
self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
return median_time
- def benchmark_CheapFns(self):
-
- input_sizes = [(10, 10, 3), (10, 100, 300)]
- batch_size = 1000
- for input_size in input_sizes:
- input_dataset = dataset_ops.Dataset.from_tensor_slices(
- (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
- for map_fn, str_id in self._get_known_cheap_fns():
- self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
-
def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
num_elems = np.prod(input_size)
name_template = "{}__batch_size_{}_input_size_{}_{}"
@@ -205,14 +195,28 @@
"Speedup: {}\n".format(batch_size, input_size, str_id,
(unoptimized_time / optimized_time)))
- def _get_known_cheap_fns(self):
- return [
- (lambda *args: [array_ops.identity(x) for x in args], "identity"),
- (lambda *args: [x + 1 for x in args], "add_const"),
- (lambda *args: args[0], "select"),
- (lambda *args: [math_ops.cast(x, dtypes.float64) for x in args],
- "cast"),
- ]
+ # Benchmarks for known cheap map functions, one benchmark method per
+ # function; each delegates to `_benchmark_helper`.
+ def benchmarkIdentity(self):
+ self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args],
+ "identity")
+
+ def benchmarkAddConst(self):
+ self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
+
+ def benchmarkSelect(self):
+ self._benchmark_helper(lambda *args: args[0], "select")
+
+ def benchmarkCast(self):
+ self._benchmark_helper(
+ lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast")
+
+ def _benchmark_helper(self, map_fn, str_id):
+ """Runs `_compare` for `map_fn` on each of a fixed set of input sizes.
+
+ Args:
+ map_fn: The function mapped over each dataset element (a pair of
+ tensors sliced from two random arrays of shape `input_size`).
+ str_id: Label used to name the reported benchmarks.
+ """
+ input_sizes = [(10, 10, 3), (10, 100, 300)]
+ batch_size = 1000
+ for input_size in input_sizes:
+ # Two random arrays sliced along their first dimension, repeated forever.
+ input_dataset = dataset_ops.Dataset.from_tensor_slices(
+ (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
+ self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index 719ce2e..be8ae5e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -25,6 +25,7 @@
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -92,6 +93,8 @@
summary_str = sess.run(summary_t)
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
float(i + 1))
+ self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
+ self._assertSummaryContains(summary_str, "Prefetch::buffer_size")
self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
0, 1)
with self.assertRaises(errors.OutOfRangeError):
@@ -100,6 +103,53 @@
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
100)
+ def testPrefetchBufferScalars(self):
+ """Checks the prefetch buffer scalar stats when `prefetch(0)` is used.
+
+ Element i is the vector [i] * i (tiling [x] x times). After every element,
+ both `Prefetch::buffer_capacity` and `Prefetch::buffer_size` are expected
+ to be reported as 0.
+ """
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+ 0).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasScalarValue(summary_str,
+ "Prefetch::buffer_capacity", 0)
+ self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
+ 0)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testFilteredElementsStats(self):
+ """Checks the Filter::dropped_elements / filtered_elements scalar stats.
+
+ range(101) filtered to multiples of 3 keeps the 34 values 0, 3, ..., 99
+ and drops the other 67.
+ """
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(101).filter(
+ lambda x: math_ops.equal(math_ops.mod(x, 3), 0)).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(34):
+ self.assertEqual(i * 3, sess.run(next_element))
+ # By the time the i-th kept element (3 * i) is produced, the 2 * i
+ # non-multiples of 3 below it have been dropped. At i == 0 nothing has
+ # been dropped yet, so the dropped-elements check is skipped.
+ if i != 0:
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::filtered_elements", float(i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::dropped_elements", 67.0)
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::filtered_elements", 34.0)
+
def testReinitialize(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index 2f5a444..b1b4c23 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -25,6 +25,14 @@
class StatsDatasetTestBase(test.TestCase):
"""Base class for testing statistics gathered in `StatsAggregator`."""
+ def _assertSummaryContains(self, summary_str, tag):
+ """Fails unless the serialized `Summary` proto has a value tagged `tag`.
+
+ Args:
+ summary_str: A serialized `summary_pb2.Summary`.
+ tag: The tag to look for; the value's contents are not checked.
+ """
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
def _assertSummaryHasCount(self, summary_str, tag, expected_value):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)
@@ -52,3 +60,12 @@
self.assertEqual(expected_value, value.histo.sum)
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
+ """Asserts the first value tagged `tag` has `simple_value == expected_value`.
+
+ Args:
+ summary_str: A serialized `summary_pb2.Summary`.
+ tag: The tag of the scalar value to check.
+ expected_value: The exact expected `simple_value` (compared with
+ `assertEqual`, so no float tolerance is applied).
+
+ Fails the test if no value carries `tag`.
+ """
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.simple_value)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index f72b827..48a7593 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -472,11 +472,8 @@
"//tensorflow/python:summary",
],
tags = [
- "manual",
"multi_and_single_gpu",
"no_pip",
- "nogpu",
- "notap",
],
)
@@ -731,12 +728,9 @@
":keras_test_lib",
],
tags = [
- "manual",
"multi_and_single_gpu",
- "no_gpu",
"no_pip",
"no_windows_gpu",
- "notap",
"notsan",
],
)
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 4903714..a3e1b96 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -114,7 +114,7 @@
self.assertEqual([v.numpy() for v in left._index.values()],
list(right._index.values()))
else:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(
sess.run(list(left._index.values())), list(right._index.values()))
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
index 24cb08f..9fc1b88 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -221,9 +221,12 @@
return small_grads, large_grads
-# threading.Lock() cannot be pickled and therefore cannot be a field of
-# CollectiveKeys.
+# threading.Lock() and threading.local() cannot be pickled and therefore cannot
+# be fields of CollectiveKeys. _thread_local does not currently need to be an
+# instance member of CollectiveKeys because we always create a new thread for
+# each tower.
_lock = threading.Lock()
+_thread_local = threading.local()
# TODO(yuefengz): use random key starts to avoid reusing keys?
@@ -266,14 +269,12 @@
# For instance keys without ids
self._instance_key_start = instance_key_start
- self._thread_local = threading.local()
-
def _get_thread_local_object(self):
# We make instance key without key ids thread local so that it will work
# with MirroredStrategy and distribute coordinator.
- if not hasattr(self._thread_local, 'instance_key'):
- self._thread_local.instance_key = self._instance_key_start
- return self._thread_local
+ # NOTE(review): `_thread_local` is module-level, so all CollectiveKeys
+ # instances used on one thread share a single `instance_key`, seeded from
+ # the first instance to reach this point on that thread. Confirm that
+ # instances with different `instance_key_start` never share a thread.
+ if not hasattr(_thread_local, 'instance_key'):
+ _thread_local.instance_key = self._instance_key_start
+ return _thread_local
def get_group_key(self, devices):
"""Returns a group key for the set of devices.
diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py
index 5348512..157618f 100644
--- a/tensorflow/contrib/distribute/python/estimator_training_test.py
+++ b/tensorflow/contrib/distribute/python/estimator_training_test.py
@@ -26,21 +26,12 @@
import threading
from absl.testing import parameterized
import numpy as np
-import six
-_portpicker_import_error = None
-try:
- import portpicker # pylint: disable=g-import-not-at-top
-except ImportError as _error: # pylint: disable=invalid-name
- _portpicker_import_error = _error
- portpicker = None
-
-# pylint: disable=g-import-not-at-top
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import parameter_server_strategy
from tensorflow.contrib.optimizer_v2 import adagrad
-from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import estimator_training as dc_training
@@ -57,7 +48,6 @@
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
-from tensorflow.python.training import server_lib
BATCH_SIZE = 10
LABEL_DIMENSION = 2
@@ -73,130 +63,38 @@
WORKER = dc._TaskType.WORKER
PS = dc._TaskType.PS
-original_run_distribute_coordinator = dc.run_distribute_coordinator
+original_run_std_server = dc._run_std_server
-# TODO(yuefengz): merge this method back to test_util.
-def _create_local_cluster(num_workers,
- num_ps,
- has_eval=False,
- protocol="grpc",
- worker_config=None,
- ps_config=None):
- if _portpicker_import_error:
- raise _portpicker_import_error # pylint: disable=raising-bad-type
- worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
- ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+class MockOsEnv(dict):
+ """A dict standing in for `os.environ` with a per-thread TF_CONFIG.
+
+ Reads and writes of the "TF_CONFIG" key through `get`, `[]` and `[]=` go to
+ a thread-local dict, so each task thread only sees the TF_CONFIG it set.
+ All other keys are stored in this shared dict itself. Other dict methods
+ (e.g. `in`, `keys()`, `pop`) are not overridden and only see the shared
+ dict. NOTE(review): a default-constructed mock starts empty, so the real
+ process environment is hidden while it is patched in.
+ """
- cluster_dict = {
- "worker": ["localhost:%s" % port for port in worker_ports],
- "ps": ["localhost:%s" % port for port in ps_ports]
- }
- if has_eval:
- cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()]
+ def __init__(self, *args):
+ # Per-thread storage for TF_CONFIG; each thread lazily gets its own dict.
+ self._thread_local = threading.local()
+ super(MockOsEnv, self).__init__(*args)
- cs = server_lib.ClusterSpec(cluster_dict)
+ def get(self, key, default=None):
+ """Like `dict.get`, but "TF_CONFIG" is looked up in this thread's dict."""
+ if not hasattr(self._thread_local, "dict"):
+ self._thread_local.dict = dict()
+ if key == "TF_CONFIG":
+ return dict.get(self._thread_local.dict, key, default)
+ else:
+ return dict.get(self, key, default)
- workers = [
- server_lib.Server(
- cs,
- job_name="worker",
- protocol=protocol,
- task_index=ix,
- config=worker_config,
- start=True) for ix in range(num_workers)
- ]
- ps_servers = [
- server_lib.Server(
- cs,
- job_name="ps",
- protocol=protocol,
- task_index=ix,
- config=ps_config,
- start=True) for ix in range(num_ps)
- ]
- if has_eval:
- evals = [
- server_lib.Server(
- cs,
- job_name="evaluator",
- protocol=protocol,
- task_index=0,
- config=worker_config,
- start=True)
- ]
- else:
- evals = []
+ def __getitem__(self, key):
+ """Like `dict.__getitem__` (KeyError if absent); "TF_CONFIG" is per-thread."""
+ if not hasattr(self._thread_local, "dict"):
+ self._thread_local.dict = dict()
+ if key == "TF_CONFIG":
+ return dict.__getitem__(self._thread_local.dict, key)
+ else:
+ return dict.__getitem__(self, key)
- return workers, ps_servers, evals
-
-
-def _create_in_process_cluster(num_workers, num_ps, has_eval=False):
- """Create an in-process cluster that consists of only standard server."""
- # Leave some memory for cuda runtime.
- if has_eval:
- gpu_mem_frac = 0.7 / (num_workers + 1)
- else:
- gpu_mem_frac = 0.7 / num_workers
-
- worker_config = config_pb2.ConfigProto()
- worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
-
- # Enable collective ops which has no impact on non-collective ops.
- # TODO(yuefengz, tucker): removing this after we move the initialization of
- # collective mgr to the session level.
- worker_config.experimental.collective_group_leader = (
- "/job:worker/replica:0/task:0")
-
- ps_config = config_pb2.ConfigProto()
- ps_config.device_count["GPU"] = 0
-
- return _create_local_cluster(
- num_workers,
- num_ps=num_ps,
- has_eval=has_eval,
- worker_config=worker_config,
- ps_config=ps_config,
- protocol="grpc")
-
-
-def _create_cluster_spec(has_chief=False,
- num_workers=1,
- num_ps=0,
- has_eval=False):
- if _portpicker_import_error:
- raise _portpicker_import_error # pylint: disable=raising-bad-type
-
- cluster_spec = {}
- if has_chief:
- cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
- if num_workers:
- cluster_spec[WORKER] = [
- "localhost:%s" % portpicker.pick_unused_port()
- for _ in range(num_workers)
- ]
- if num_ps:
- cluster_spec[PS] = [
- "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
- ]
- if has_eval:
- cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
- return cluster_spec
-
-
-def _bytes_to_str(maybe_bytes):
- if isinstance(maybe_bytes, six.string_types):
- return maybe_bytes
- else:
- return str(maybe_bytes, "utf-8")
-
-
-def _strip_protocol(target):
- # cluster_spec expects "host:port" strings.
- if "//" in target:
- return target.split("//")[1]
- else:
- return target
+ def __setitem__(self, key, val):
+ """Like `dict.__setitem__`; "TF_CONFIG" is stored in this thread's dict."""
+ if not hasattr(self._thread_local, "dict"):
+ self._thread_local.dict = dict()
+ if key == "TF_CONFIG":
+ return dict.__setitem__(self._thread_local.dict, key, val)
+ else:
+ return dict.__setitem__(self, key, val)
class DistributeCoordinatorIntegrationTest(test.TestCase,
@@ -205,22 +103,20 @@
@classmethod
def setUpClass(cls):
- """Create a local cluster with 2 workers."""
+ """Create an in-process cluster with 3 workers, 2 ps and 1 evaluator."""
- cls._workers, cls._ps, cls._evals = _create_in_process_cluster(
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=2, has_eval=True)
- cls._cluster_spec = {
- "worker": [
- _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
- ],
- "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps],
- "evaluator": [
- _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals
- ]
- }
def setUp(self):
self._model_dir = tempfile.mkdtemp()
- self._event = threading.Event()
+ # Replace os.environ with a fresh MockOsEnv for each test so that every
+ # task thread can set its own TF_CONFIG without affecting the others.
+ self._mock_os_env = MockOsEnv()
+ self._mock_context = test.mock.patch.object(os, "environ",
+ self._mock_os_env)
super(DistributeCoordinatorIntegrationTest, self).setUp()
+ self._mock_context.__enter__()
+
+ def tearDown(self):
+ # Undo the os.environ patch applied in setUp.
+ self._mock_context.__exit__(None, None, None)
+ super(DistributeCoordinatorIntegrationTest, self).tearDown()
def dataset_input_fn(self, x, y, batch_size, shuffle):
@@ -391,43 +287,17 @@
train_distribute, eval_distribute, remote_cluster=self._cluster_spec)
self._inspect_train_and_eval_events(estimator)
- def _mock_run_distribute_coordinator(
- self,
- worker_fn,
- strategy,
- eval_fn,
- eval_strategy,
- mode=dc.CoordinatorMode.STANDALONE_CLIENT,
- cluster_spec=None,
- session_config=None):
- # Calls the origial `run_distribute_coordinator` method but gets task config
- # from environment variables and then signals the caller.
- task_type = None
- task_id = None
- if not cluster_spec:
- cluster_spec = None
- tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
- if not cluster_spec:
- cluster_spec = tf_config.get("cluster", {})
- task_env = tf_config.get("task", {})
- if task_env:
- task_type = task_env.get("type", task_type)
- task_id = int(task_env.get("index", task_id))
- self._event.set()
- original_run_distribute_coordinator(
- worker_fn,
- strategy,
- eval_fn,
- eval_strategy,
- mode=mode,
- cluster_spec=cluster_spec,
- task_type=task_type,
- task_id=task_id,
- session_config=session_config)
+ def _mock_run_std_server(self, *args, **kwargs):
+ """Calls the real `dc._run_std_server`, then waits on `self._barrier`.
+
+ NOTE(review): if a task raises before reaching the barrier, the other
+ threads presumably block in `wait()` indefinitely -- confirm whether
+ `dc._Barrier` supports a timeout.
+ """
+ ret = original_run_std_server(*args, **kwargs)
+ # Wait for all std servers to be brought up in order to reduce the chance of
+ # remote sessions taking local ports that have been assigned to std servers.
+ self._barrier.wait()
+ return ret
- def _task_thread(self, train_distribute, eval_distribute):
- with test.mock.patch.object(dc, "run_distribute_coordinator",
- self._mock_run_distribute_coordinator):
+ def _task_thread(self, train_distribute, eval_distribute, tf_config):
+ # Runs on its own thread. With os.environ patched to a MockOsEnv in setUp,
+ # this TF_CONFIG assignment is visible only to the current thread.
+ os.environ["TF_CONFIG"] = json.dumps(tf_config)
+ # NOTE(review): every task thread patches the shared `dc` module
+ # concurrently; if the patches exit out of order, `dc._run_std_server`
+ # may be left pointing at this mock after the test. Consider patching
+ # once on the main thread instead.
+ with test.mock.patch.object(dc, "_run_std_server",
+ self._mock_run_std_server):
+ self._complete_flow(train_distribute, eval_distribute)
def _run_task_in_thread(self, cluster_spec, task_type, task_id,
@@ -448,13 +318,10 @@
"index": task_id
}
}
- self._event.clear()
t = threading.Thread(
- target=self._task_thread, args=(train_distribute, eval_distribute))
- with test.mock.patch.dict("os.environ",
- {"TF_CONFIG": json.dumps(tf_config)}):
- t.start()
- self._event.wait()
+ target=self._task_thread,
+ args=(train_distribute, eval_distribute, tf_config))
+ t.start()
return t
def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute,
@@ -489,7 +356,11 @@
else:
eval_distribute = None
- cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ cluster_spec = multi_worker_test_base.create_cluster_spec(
+ num_workers=3, num_ps=2, has_eval=True)
+ # 3 workers, 2 ps and 1 evaluator.
+ self._barrier = dc._Barrier(6)
+
threads = self._run_multiple_tasks_in_threads(
cluster_spec, train_distribute, eval_distribute)
for task_type, ts in threads.items():
@@ -516,7 +387,10 @@
else:
eval_distribute = None
- cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ cluster_spec = multi_worker_test_base.create_cluster_spec(
+ num_workers=3, num_ps=0, has_eval=True)
+ # 3 workers and 1 evaluator.
+ self._barrier = dc._Barrier(4)
threads = self._run_multiple_tasks_in_threads(
cluster_spec, train_distribute, eval_distribute)
threads[WORKER][0].join()
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 5f35e38..8165a70 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -732,14 +732,22 @@
with self.cached_session():
keras.backend.set_image_data_format('channels_last')
num_samples = 10000
+
+ # Train and predict datasets are created with the same input numpy arrays.
x_train = np.random.rand(num_samples, 1)
y_train = 3 * x_train
x_train = x_train.astype('float32')
y_train = y_train.astype('float32')
+ # The model is built once and the initial weights are saved.
+ # This is used to initialize the model for both the distribution and
+ # non-distribution run.
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(1,)))
+ initial_weights = model.get_weights()
+
def fit_and_predict(with_distribution=None):
- model = keras.Sequential()
- model.add(keras.layers.Dense(1, input_shape=(1,)))
+ model.set_weights(initial_weights)
model.compile(
loss=keras.losses.mean_squared_error,
optimizer=gradient_descent.GradientDescentOptimizer(0.5),
@@ -751,12 +759,14 @@
train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train,
y_train))
train_dataset = batch_wrapper(train_dataset, batch_size, distribution)
- # Running only 100 steps instead of the full dataset to keep test
- # duration small.
- model.fit(x=train_dataset, epochs=1, steps_per_epoch=100)
+ # We have initialized the model to the same weight for the distribution
+ # and non-distribution run. If you want to initialize the model to
+ # random weights for each run, you need to run the model through the
+ # entire dataset at least once to ensure that the weights converge to
+ # the same value.
+ model.fit(x=train_dataset, epochs=1, steps_per_epoch=10)
weights = model.get_weights()
-
x_predict = [[1.], [2.], [3.], [4.]]
predict_batch_size = 4
if with_distribution:
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
index 18b4503..9f92ba7 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -36,9 +36,29 @@
from tensorflow.python.client import session
from tensorflow.python.estimator import run_config
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
+ASSIGNED_PORTS = set()
+lock = threading.Lock()
+
+
+def pick_unused_port():
+ """Returns an unused and unassigned local port.
+
+ Every port handed out is recorded in `ASSIGNED_PORTS` under `lock`, so
+ callers in this process (including concurrent threads) never receive the
+ same port twice. Only ports above 10000 are accepted.
+ """
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+
+ with lock:
+ while True:
+ port = portpicker.pick_unused_port()
+ if port > 10000 and port not in ASSIGNED_PORTS:
+ ASSIGNED_PORTS.add(port)
+ logging.info('Using local port %r', port)
+ return port
+
+
def _create_cluster(num_workers,
num_ps,
has_chief=False,
@@ -49,8 +69,8 @@
"""Creates and starts local servers and returns the cluster_spec dict."""
if _portpicker_import_error:
raise _portpicker_import_error # pylint: disable=raising-bad-type
- worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
- ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+ worker_ports = [pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [pick_unused_port() for _ in range(num_ps)]
cluster_dict = {}
if num_workers > 0:
@@ -58,9 +78,9 @@
if num_ps > 0:
cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
if has_eval:
- cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()]
+ cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
if has_chief:
- cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()]
+ cluster_dict['chief'] = ['localhost:%s' % pick_unused_port()]
cs = server_lib.ClusterSpec(cluster_dict)
@@ -139,11 +159,36 @@
num_workers,
num_ps=num_ps,
has_chief=has_chief,
+ has_eval=has_eval,
worker_config=worker_config,
ps_config=ps_config,
protocol='grpc')
+def create_cluster_spec(has_chief=False,
+ num_workers=1,
+ num_ps=0,
+ has_eval=False):
+ """Create a cluster spec with tasks with unused local ports.
+
+ No servers are started; addresses are only reserved via `pick_unused_port`.
+
+ Args:
+ has_chief: If True, add a single 'chief' task.
+ num_workers: Number of 'worker' tasks; the job is omitted when 0.
+ num_ps: Number of 'ps' tasks; the job is omitted when 0.
+ has_eval: If True, add a single 'evaluator' task.
+
+ Returns:
+ A dict mapping job names to lists of 'localhost:<port>' addresses.
+
+ Raises:
+ The stored `_portpicker_import_error` (presumably the ImportError from
+ importing portpicker) if portpicker is unavailable.
+ """
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+
+ cluster_spec = {}
+ if has_chief:
+ cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
+ if num_workers:
+ cluster_spec['worker'] = [
+ 'localhost:%s' % pick_unused_port() for _ in range(num_workers)
+ ]
+ if num_ps:
+ cluster_spec['ps'] = [
+ 'localhost:%s' % pick_unused_port() for _ in range(num_ps)
+ ]
+ if has_eval:
+ cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
+ return cluster_spec
+
+
class MultiWorkerTestBase(test.TestCase):
"""Base class for testing multi node strategy and dataset."""
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
index 8dad80a..c32ea9a 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
@@ -93,12 +93,12 @@
bijector.inverse_log_det_jacobian(y, event_ndims=1)))
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softsign(validate_args=True)
assert_scalar_congruency(bijector, lower_x=-20., upper_x=20.)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softsign(validate_args=True)
x = np.linspace(-20., 20., 100).astype(np.float32)
y = np.linspace(-0.99, 0.99, 100).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
index f073f51..9b9b3ce 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
@@ -212,7 +212,7 @@
def testStrWorksCorrectlyScalar(self):
normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
self.assertEqual(
- ("tf.distributions.Normal("
+ ("tfp.distributions.Normal("
"\"Normal/\", "
"batch_shape=(), "
"event_shape=(), "
@@ -221,7 +221,7 @@
chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
self.assertEqual(
- ("tf.distributions.Chi2("
+ ("tfp.distributions.Chi2("
"\"silly/\", " # What a silly name that is!
"batch_shape=(2,), "
"event_shape=(), "
@@ -230,7 +230,7 @@
exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
self.assertEqual(
- ("tf.distributions.Exponential(\"Exponential/\", "
+ ("tfp.distributions.Exponential(\"Exponential/\", "
# No batch shape.
"event_shape=(), "
"dtype=float32)"),
@@ -240,7 +240,7 @@
mvn_static = tfd.MultivariateNormalDiag(
loc=np.zeros([2, 2]), name="MVN")
self.assertEqual(
- ("tf.distributions.MultivariateNormalDiag("
+ ("tfp.distributions.MultivariateNormalDiag("
"\"MVN/\", "
"batch_shape=(2,), "
"event_shape=(2,), "
@@ -251,7 +251,7 @@
loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32),
name="MVN2")
self.assertEqual(
- ("tf.distributions.MultivariateNormalDiag("
+ ("tfp.distributions.MultivariateNormalDiag("
"\"MVN2/\", "
"batch_shape=(?,), " # Partially known.
"event_shape=(3,), "
@@ -261,7 +261,7 @@
def testReprWorksCorrectlyScalar(self):
normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
self.assertEqual(
- ("<tf.distributions.Normal"
+ ("<tfp.distributions.Normal"
" 'Normal/'"
" batch_shape=()"
" event_shape=()"
@@ -270,7 +270,7 @@
chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
self.assertEqual(
- ("<tf.distributions.Chi2"
+ ("<tfp.distributions.Chi2"
" 'silly/'" # What a silly name that is!
" batch_shape=(2,)"
" event_shape=()"
@@ -279,7 +279,7 @@
exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
self.assertEqual(
- ("<tf.distributions.Exponential"
+ ("<tfp.distributions.Exponential"
" 'Exponential/'"
" batch_shape=<unknown>"
" event_shape=()"
@@ -290,7 +290,7 @@
mvn_static = tfd.MultivariateNormalDiag(
loc=np.zeros([2, 2]), name="MVN")
self.assertEqual(
- ("<tf.distributions.MultivariateNormalDiag"
+ ("<tfp.distributions.MultivariateNormalDiag"
" 'MVN/'"
" batch_shape=(2,)"
" event_shape=(2,)"
@@ -301,7 +301,7 @@
loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32),
name="MVN2")
self.assertEqual(
- ("<tf.distributions.MultivariateNormalDiag"
+ ("<tfp.distributions.MultivariateNormalDiag"
" 'MVN2/'"
" batch_shape=(?,)" # Partially known.
" event_shape=(3,)"
diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py
index bb9b804..3ba1c3a 100644
--- a/tensorflow/contrib/distributions/python/ops/autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py
@@ -65,13 +65,14 @@
```
where the ellipses (`...`) represent `n-2` composed calls to `fn`, `fn`
- constructs a `tf.distributions.Distribution`-like instance, and `x0` is a
+ constructs a `tfp.distributions.Distribution`-like instance, and `x0` is a
fixed initializing `Tensor`.
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
def normal_fn(self, event_size):
n = event_size * (event_size + 1) / 2
@@ -127,7 +128,7 @@
Args:
distribution_fn: Python `callable` which constructs a
- `tf.distributions.Distribution`-like instance from a `Tensor` (e.g.,
+ `tfp.distributions.Distribution`-like instance from a `Tensor` (e.g.,
`sample0`). The function must respect the "autoregressive property",
i.e., there exists a permutation of event such that each coordinate is a
diffeomorphic function of on preceding coordinates.
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index 519077b..612376e 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -45,7 +45,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
dtype = np.float32
dims = 2
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
index 296e66f..3b3d8ee 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
@@ -61,8 +61,8 @@
`shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves
this property by zeroing out weights in its `masked_dense` layers.
- In the `tf.distributions` framework, a "normalizing flow" is implemented as a
- `tf.contrib.distributions.bijectors.Bijector`. The `forward` "autoregression"
+ In the `tfp` framework, a "normalizing flow" is implemented as a
+ `tfp.bijectors.Bijector`. The `forward` "autoregression"
is implemented using a `tf.while_loop` and a deep neural network (DNN) with
masked weights such that the autoregressive property is automatically met in
the `inverse`.
@@ -126,8 +126,9 @@
#### Examples
```python
- tfd = tf.contrib.distributions
- tfb = tfd.bijectors
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+ tfb = tfp.bijectors
dims = 5
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
index f182a1a..178c3c9 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
@@ -41,9 +41,10 @@
"""Permutes the rightmost dimension of a `Tensor`.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfb = tfp.bijectors
- reverse = tfd.bijectors.Permute(permutation=[2, 1, 0])
+ reverse = tfb.Permute(permutation=[2, 1, 0])
reverse.forward([-1., 0., 1.])
# ==> [1., 0., -1]
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
index 773ae24..0bcb08c 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
@@ -90,8 +90,9 @@
#### Example Use
```python
- tfd = tf.contrib.distributions
- tfb = tfd.bijectors
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+ tfb = tfp.bijectors
# A common choice for a normalizing flow is to use a Gaussian for the base
# distribution. (However, any continuous distribution would work.) E.g.,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
index c828222..71ac290 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
@@ -80,9 +80,10 @@
Example usage:
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfb = tfp.bijectors
- r = tfd.bijectors.Reshape(event_shape_out=[1, -1])
+ r = tfb.Reshape(event_shape_out=[1, -1])
r.forward([3., 4.]) # shape [2]
# ==> [[3., 4.]] # shape [1, 2]
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
index 6fbe866..0a6d690 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
@@ -42,7 +42,10 @@
#### Examples
```python
- tfb = tf.contrib.distributions.bijectors
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+ tfb = tfp.bijectors
+
b = tfb.ScaleTriL(
diag_bijector=tfb.Exp(),
diag_shift=None)
diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py
index cb5223b..c461833 100644
--- a/tensorflow/contrib/distributions/python/ops/cauchy.py
+++ b/tensorflow/contrib/distributions/python/ops/cauchy.py
@@ -63,7 +63,8 @@
Examples of initialization of one or a batch of distributions.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Define a single scalar Cauchy distribution.
dist = tfd.Cauchy(loc=0., scale=3.)
diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py
index affc64a..507c5d3 100644
--- a/tensorflow/contrib/distributions/python/ops/deterministic.py
+++ b/tensorflow/contrib/distributions/python/ops/deterministic.py
@@ -198,8 +198,11 @@
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Initialize a single Deterministic supported at zero.
- constant = tf.contrib.distributions.Deterministic(0.)
+ constant = tfd.Deterministic(0.)
constant.prob(0.)
==> 1.
constant.prob(2.)
@@ -208,7 +211,7 @@
# Initialize a [2, 2] batch of scalar constants.
loc = [[0., 1.], [2., 3.]]
x = [[0., 1.1], [1.99, 3.]]
- constant = tf.contrib.distributions.Deterministic(loc)
+ constant = tfd.Deterministic(loc)
constant.prob(x)
==> [[1., 0.], [0., 1.]]
```
@@ -310,7 +313,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single VectorDeterministic supported at [0., 2.] in R^2.
constant = tfd.Deterministic([0., 2.])
diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
index acdea4d..4b50df5 100644
--- a/tensorflow/contrib/distributions/python/ops/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -63,7 +63,8 @@
Examples of initialization of one or a batch of distributions.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Define a single scalar Gumbel distribution.
dist = tfd.Gumbel(loc=0., scale=3.)
diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py
index b02c403..f121637 100644
--- a/tensorflow/contrib/distributions/python/ops/half_normal.py
+++ b/tensorflow/contrib/distributions/python/ops/half_normal.py
@@ -66,15 +66,18 @@
Examples of initialization of one or a batch of distributions.
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Define a single scalar HalfNormal distribution.
- dist = tf.contrib.distributions.HalfNormal(scale=3.0)
+ dist = tfd.HalfNormal(scale=3.0)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1.)
# Define a batch of two scalar valued HalfNormals.
# The first has scale 11.0, the second 22.0
- dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0])
+ dist = tfd.HalfNormal(scale=[11.0, 22.0])
# Evaluate the pdf of the first distribution on 1.0, and the second on 1.5,
# returning a length two tensor.
diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py
index 0672702..e1cfff3 100644
--- a/tensorflow/contrib/distributions/python/ops/independent.py
+++ b/tensorflow/contrib/distributions/python/ops/independent.py
@@ -70,7 +70,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Make independent distribution from a 2-batch Normal.
ind = tfd.Independent(
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index 70d050d..4526282 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -89,7 +89,9 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
dist = tfd.InverseGamma(concentration=3.0, rate=2.0)
dist2 = tfd.InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
```
diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py
index 02e3bad..21c9b5a 100644
--- a/tensorflow/contrib/distributions/python/ops/logistic.py
+++ b/tensorflow/contrib/distributions/python/ops/logistic.py
@@ -61,7 +61,8 @@
Examples of initialization of one or a batch of distributions.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Define a single scalar Logistic distribution.
dist = tfd.Logistic(loc=0., scale=3.)
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index 3b7114e..52b67f2 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -50,7 +50,9 @@
```python
# Create a mixture of two Gaussians:
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
mix = 0.3
bimix_gauss = tfd.Mixture(
cat=tfd.Categorical(probs=[mix, 1.-mix]),
diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
index 8ffee94..f4d394f 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
@@ -44,7 +44,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
### Create a mixture of two scalar Gaussians:
@@ -113,12 +114,12 @@
"""Construct a `MixtureSameFamily` distribution.
Args:
- mixture_distribution: `tf.distributions.Categorical`-like instance.
+ mixture_distribution: `tfp.distributions.Categorical`-like instance.
Manages the probability of selecting components. The number of
categories must match the rightmost batch dimension of the
`components_distribution`. Must have either scalar `batch_shape` or
`batch_shape` matching `components_distribution.batch_shape[:-1]`.
- components_distribution: `tf.distributions.Distribution`-like instance.
+ components_distribution: `tfp.distributions.Distribution`-like instance.
Right-most batch dimension indexes components.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
index cd0c282..0b5b76b 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
@@ -85,7 +85,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 2-variate Gaussian.
mvn = tfd.MultivariateNormalDiag(
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
index 74d9d04..8054608 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -87,7 +87,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`,
# `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
index dbc4c1b..bcb4937 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
@@ -73,7 +73,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate Gaussian.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
index efe5a6d..8fdc998 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
@@ -91,7 +91,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate Gaussian.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
index c6a23e4..c21f70f 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -77,13 +77,14 @@
```
Trainable (batch) lower-triangular matrices can be created with
- `tf.contrib.distributions.matrix_diag_transform()` and/or
- `tf.contrib.distributions.fill_triangular()`
+ `tfp.distributions.matrix_diag_transform()` and/or
+ `tfp.distributions.fill_triangular()`
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate Gaussian.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
index 7a7ad1b..85683e3 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
@@ -220,7 +220,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Create two batches of PoissonLogNormalQuadratureCompounds, one with
# prior `loc = 0.` and another with `loc = 1.` In both cases `scale = 1.`
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index 18a0f75..134658d 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -196,8 +196,9 @@
parameter determining the unnormalized probability of that component.
```python
- tfd = tf.contrib.distributions
- tfb = tfd.bijectors
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+ tfb = tfp.bijectors
net = wavenet(inputs)
loc, unconstrained_scale, logits = tf.split(net,
diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
index a9d0fb4..4b520b9 100644
--- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
+++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
@@ -124,7 +124,7 @@
tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight)
distribution: `tf.Distribution`-like instance. Distribution that is
transformed to produce this distribution.
- Default is `tf.distributions.Normal(0., 1.)`.
+ Default is `tfp.distributions.Normal(0., 1.)`.
Must be a scalar-batch, scalar-event distribution. Typically
`distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
a function of non-trainable parameters. WARNING: If you backprop through
diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
index c25e8c5..af22f48 100644
--- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py
+++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
@@ -30,27 +30,27 @@
`[0, 1]`. Then you might do this:
```python
-tfd = tf.contrib.distributions
+ from tensorflow_probability.python.distributions.internal import statistical_testing
-expected_mean = ...
-num_samples = 5000
-samples = ... draw 5000 samples from P
+ expected_mean = ...
+ num_samples = 5000
+ samples = ... draw 5000 samples from P
-# Check that the mean looks right
-check1 = tfd.assert_true_mean_equal_by_dkwm(
- samples, low=0., high=1., expected=expected_mean,
- false_fail_rate=1e-6)
+ # Check that the mean looks right
+ check1 = statistical_testing.assert_true_mean_equal_by_dkwm(
+ samples, low=0., high=1., expected=expected_mean,
+ false_fail_rate=1e-6)
-# Check that the difference in means detectable with 5000 samples is
-# small enough
-check2 = tf.assert_less(
- tfd.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, low=0., high=1.0,
- false_fail_rate=1e-6, false_pass_rate=1e-6),
- 0.01)
+ # Check that the difference in means detectable with 5000 samples is
+ # small enough
+ check2 = tf.assert_less(
+ statistical_testing.min_discrepancy_of_true_means_detectable_by_dkwm(
+ num_samples, low=0., high=1.0,
+ false_fail_rate=1e-6, false_pass_rate=1e-6),
+ 0.01)
-# Be sure to execute both assertion ops
-sess.run([check1, check2])
+ # Be sure to execute both assertion ops
+ sess.run([check1, check2])
```
The second assertion is an instance of experiment design. It's a
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index 3c8aae2..a3d1783 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -300,7 +300,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Create two batches of VectorDiffeomixtures, one with mix_loc=[0.],
# another with mix_loc=[1]. In both cases, `K=2` and the affine
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
index 73356a3..36cbd71 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
@@ -90,7 +90,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 2-variate VectorExponential, supported on
# {(x, y) in R^2 : x > 0, y > 0}.
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
index 9a47b48..fd5bf9e 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
@@ -108,7 +108,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 2-variate VectorExponential, supported on
# {(x, y) in R^2 : x > 0, y > 0}.
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
index e68ddc5..8cd4e12 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
@@ -102,7 +102,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 2-variate VectorLaplace.
vla = tfd.VectorLaplaceDiag(
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
index 3923161..67d2ccd 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
@@ -110,7 +110,8 @@
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate VectorLaplace with some desired covariance.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
index 49ffff2..da57d0c 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
@@ -152,7 +152,7 @@
broadcastable with `event_shape`.
distribution: `tf.Distribution`-like instance. Distribution from which `k`
iid samples are used as input to transformation `F`. Default is
- `tf.distributions.Normal(loc=0., scale=1.)`.
+ `tfp.distributions.Normal(loc=0., scale=1.)`.
Must be a scalar-batch, scalar-event distribution. Typically
`distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
a function of non-trainable parameters. WARNING: If you backprop through
diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
index f289b39e..bad91a0 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
@@ -92,7 +92,8 @@
Extra leading dimensions, if provided, allow for batches.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate vector Student's t-distribution.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index 49b9de0..ee2fc58 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -480,11 +480,14 @@
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Initialize a single 3x3 Wishart with Cholesky factored scale matrix and 5
# degrees-of-freedom.(*)
df = 5
chol_scale = tf.cholesky(...) # Shape is [3, 3].
- dist = tf.contrib.distributions.WishartCholesky(df=df, scale=chol_scale)
+ dist = tfd.WishartCholesky(df=df, scale=chol_scale)
# Evaluate this on an observation in R^3, returning a scalar.
x = ... # A 3x3 positive definite matrix.
@@ -498,14 +501,14 @@
# Initialize two 3x3 Wisharts with Cholesky factored scale matrices.
df = [5, 4]
chol_scale = tf.cholesky(...) # Shape is [2, 3, 3].
- dist = tf.contrib.distributions.WishartCholesky(df=df, scale=chol_scale)
+ dist = tfd.WishartCholesky(df=df, scale=chol_scale)
# Evaluate this on four observations.
x = [[x0, x1], [x2, x3]] # Shape is [2, 2, 3, 3].
dist.prob(x) # Shape is [2, 2].
# (*) - To efficiently create a trainable covariance matrix, see the example
- # in tf.contrib.distributions.matrix_diag_transform.
+ # in tfp.distributions.matrix_diag_transform.
```
"""
@@ -604,11 +607,14 @@
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Initialize a single 3x3 Wishart with Full factored scale matrix and 5
# degrees-of-freedom.(*)
df = 5
scale = ... # Shape is [3, 3]; positive definite.
- dist = tf.contrib.distributions.WishartFull(df=df, scale=scale)
+ dist = tfd.WishartFull(df=df, scale=scale)
# Evaluate this on an observation in R^3, returning a scalar.
x = ... # A 3x3 positive definite matrix.
@@ -622,14 +628,14 @@
# Initialize two 3x3 Wisharts with Full factored scale matrices.
df = [5, 4]
scale = ... # Shape is [2, 3, 3].
- dist = tf.contrib.distributions.WishartFull(df=df, scale=scale)
+ dist = tfd.WishartFull(df=df, scale=scale)
# Evaluate this on four observations.
x = [[x0, x1], [x2, x3]] # Shape is [2, 2, 3, 3]; xi is positive definite.
dist.prob(x) # Shape is [2, 2].
# (*) - To efficiently create a trainable covariance matrix, see the example
- # in tf.contrib.distributions.matrix_diag_transform.
+ # in tfp.distributions.matrix_diag_transform.
```
"""
diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md
index 86d2034..4bd2769 100644
--- a/tensorflow/contrib/eager/README.md
+++ b/tensorflow/contrib/eager/README.md
@@ -44,7 +44,6 @@
For an introduction to eager execution in TensorFlow, see:
-- [User Guide](https://www.tensorflow.org/guide/eager) ([source](../../docs_src/guide/eager.md))
-- Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb)
-- Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb)
-- Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb)
+- [User Guide](https://www.tensorflow.org/guide/eager) ([source](https://github.com/tensorflow/docs/blob/master/site/en/guide/eager.md))
+- Notebook: [Basic Usage](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/eager_basics.ipynb)
+- Notebook: [Automatic differentiation and gradient tape](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb)
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 7ed77bc..11f60c8 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -20,6 +20,7 @@
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
+from tensorflow.python.estimator.canned import head as head_lib
def _validate_input_fn_and_repeat_dataset(train_input_fn):
@@ -33,7 +34,18 @@
return _input_fn
-class _BoostedTreesEstimator(estimator.Estimator):
+# pylint: disable=protected-access
+def _is_classification_head(head):
+  """Returns whether `head` is a known canned classification head."""
+  # Only the classification heads from canned/head.py are checked; any other
+  # head type (unless it subclasses one of these) is reported as False.
+  known_classification_heads = (
+      head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss,
+      head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss)
+  return isinstance(head, known_classification_heads)
+
+
+class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase):
"""An Estimator for Tensorflow Boosted Trees models."""
def __init__(self,
@@ -96,8 +108,10 @@
negative gain). For pre and post pruning, you MUST provide
tree_complexity >0.
+ Raises:
+ ValueError: when invalid arguments are given or unsupported
+ functionality is requested.
"""
- # pylint:disable=protected-access
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
@@ -115,8 +129,14 @@
config=config)
super(_BoostedTreesEstimator, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
- # pylint:enable=protected-access
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns,
+ head=head,
+ center_bias=center_bias,
+ is_classification=_is_classification_head(head))
+ # pylint: enable=protected-access
def boosted_trees_classifier_train_in_memory(
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index b1581f3..e23d9c0 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -360,5 +360,79 @@
[pred['predictions'] for pred in predictions])
+class BoostedTreesDebugOutputTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._head = canned_boosted_trees._create_regression_head(label_dimension=1)
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+ }
+
+ def testContribEstimatorThatDFCIsInPredictions(self):
+ # pylint:disable=protected-access
+ head = canned_boosted_trees._create_regression_head(label_dimension=1)
+ train_input_fn = _make_train_input_fn(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ head=head,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+ # pylint:enable=protected-access
+
+ num_steps = 100
+ # Train for a few steps. Validate debug outputs in prediction dicts.
+ est.train(train_input_fn, steps=num_steps)
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn)
+ biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+ for pred in debug_predictions])
+ self.assertAllClose([1.8] * 5, biases)
+ self.assertAllClose(({
+ 0: -0.070499420166015625,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: -0.53763031959533691,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: -0.51756942272186279,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: 0.1563495397567749,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: 0.96934974193572998,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }), dfcs)
+
+ # Assert sum(dfcs) + bias == predictions.
+ expected_predictions = [[1.6345005], [1.32570302], [1.1874305],
+ [2.01968288], [2.83268309]]
+ predictions = [
+ [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases)
+ ]
+ self.assertAllClose(expected_predictions, predictions)
+
+ # Test when user doesn't include bias or dfc in predict_keys.
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn, predict_keys=['predictions'])
+ for prediction_dict in debug_predictions:
+ self.assertTrue('bias' in prediction_dict)
+ self.assertTrue('dfc' in prediction_dict)
+ self.assertTrue('predictions' in prediction_dict)
+ self.assertEqual(len(prediction_dict), 3)
+
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py
index 66c46e6..49f7bbd 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks.py
@@ -53,6 +53,7 @@
```
Current limitations of this approach are:
+
* It doesn't support multi-node distributed mode.
* It doesn't support saveable objects other than variables (such as boosted
tree support)
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
index 0185ef6..e47342b 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
@@ -265,7 +265,7 @@
tensors = []
for (data_format, use_gpu) in GetTestConfigs():
tensors.append(_SetupVal(data_format, use_gpu))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = sess.run(tensors)
for i in range(1, len(values)):
self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5)
@@ -282,7 +282,7 @@
data_format, filter_format, dtype)
tensors.append(result)
ref_tensors.append(expected)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = sess.run(tensors)
ref_values = sess.run(ref_tensors)
for i in range(len(tensors)):
@@ -493,7 +493,7 @@
if gpu_only and not test.is_gpu_available():
tf_logging.info("Skipping OpEdgeCases tests.")
return
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Illegal strides.
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Convolutional strides are not supported in "
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index d389748..8bc4db8 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
@@ -773,9 +773,9 @@
structured_generator_inputs: A list of Tensors representing the random noise
that must have high mutual information with the generator output. List
length should match `predicted_distributions`.
- predicted_distributions: A list of tf.Distributions. Predicted by the
- recognizer, and used to evaluate the likelihood of the structured noise.
- List length should match `structured_generator_inputs`.
+ predicted_distributions: A list of `tfp.distributions.Distribution`s.
+ Predicted by the recognizer, and used to evaluate the likelihood of the
+ structured noise. List length should match `structured_generator_inputs`.
weights: Optional `Tensor` whose rank is either 0, or the same dimensions as
`structured_generator_inputs`.
scope: The scope for the operations performed in computing the loss.
diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py
index a462b68..b9ac1bf 100644
--- a/tensorflow/contrib/gan/python/namedtuples.py
+++ b/tensorflow/contrib/gan/python/namedtuples.py
@@ -91,9 +91,9 @@
structured_generator_inputs: A list of Tensors representing the random noise
that must have high mutual information with the generator output. List
length should match `predicted_distributions`.
- predicted_distributions: A list of tf.Distributions. Predicted by the
- recognizer, and used to evaluate the likelihood of the structured noise.
- List length should match `structured_generator_inputs`.
+ predicted_distributions: A list of `tfp.distributions.Distribution`s.
+ Predicted by the recognizer, and used to evaluate the likelihood of the
+ structured noise. List length should match `structured_generator_inputs`.
discriminator_and_aux_fn: The original discriminator function that returns
a tuple of (logits, `predicted_distributions`).
"""
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index 58f3480..64d6706 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -399,7 +399,7 @@
target_tensor = train._generate_stargan_random_domain_target(
batch_size, domain_numbers)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
targets = sess.run(target_tensor)
self.assertTupleEqual((batch_size, domain_numbers), targets.shape)
for target in targets:
@@ -676,7 +676,7 @@
self.assertIsInstance(model_loss, namedtuples.GANLoss)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index fed8a77..27aed09 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -233,7 +233,7 @@
([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
def testGrid2LSTMCellWithRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -261,7 +261,7 @@
"""
def testGrid2BasicRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
@@ -292,7 +292,7 @@
[[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellTied(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
@@ -323,7 +323,7 @@
[[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellWithRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -348,7 +348,7 @@
"""
def testGrid1LSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)) as root_scope:
x = array_ops.zeros([1, 3])
@@ -410,7 +410,7 @@
"""
def testGrid3LSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -455,7 +455,7 @@
"""
def testGridRNNEdgeCasesLikeRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([3, 2])
@@ -481,7 +481,7 @@
self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],))
def testGridRNNEdgeCasesNoOutput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -541,7 +541,7 @@
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -581,7 +581,7 @@
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -623,7 +623,7 @@
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -663,7 +663,7 @@
self.assertEqual(out[0].get_shape(), (3, num_units))
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -700,7 +700,7 @@
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((3, input_size))
@@ -715,7 +715,7 @@
def testGrid2LSTMCellLegacy(self):
"""Test for legacy case (when state_is_tuple=False)."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
index 9ed0175..f44edaa 100644
--- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
+++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
@@ -29,7 +29,7 @@
class InputPipelineOpsTest(test.TestCase):
def testObtainNext(self):
- with self.test_session():
+ with self.cached_session():
var = state_ops.variable_op([], dtypes.int64)
state_ops.assign(var, -1).op.run()
c = constant_op.constant(["a", "b"])
@@ -45,7 +45,7 @@
def testSeekNext(self):
string_list = ["a", "b", "c"]
- with self.test_session() as session:
+ with self.cached_session() as session:
elem = input_pipeline_ops.seek_next(string_list)
session.run([variables.global_variables_initializer()])
self.assertEqual(b"a", session.run(elem))
@@ -65,7 +65,7 @@
def testSeekNextLimitEpochs(self):
string_list = ["a", "b", "c"]
- with self.test_session() as session:
+ with self.cached_session() as session:
elem = input_pipeline_ops.seek_next(string_list, num_epochs=1)
session.run([
variables.local_variables_initializer(),
@@ -75,7 +75,7 @@
def testSeekNextLimitEpochsThree(self):
string_list = ["a", "b", "c"]
- with self.test_session() as session:
+ with self.cached_session() as session:
elem = input_pipeline_ops.seek_next(string_list, num_epochs=3)
session.run([
variables.local_variables_initializer(),
diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py
index 7250753..4d5cc24 100644
--- a/tensorflow/contrib/kernel_methods/python/losses_test.py
+++ b/tensorflow/contrib/kernel_methods/python/losses_test.py
@@ -32,7 +32,7 @@
def testInvalidLogitsShape(self):
"""An error is raised when logits have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2,))
labels = constant_op.constant([0, 1])
with self.assertRaises(ValueError):
@@ -40,7 +40,7 @@
def testInvalidLabelsShape(self):
"""An error is raised when labels have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(1, 1, 2))
with self.assertRaises(ValueError):
@@ -48,7 +48,7 @@
def testInvalidWeightsShape(self):
"""An error is raised when weights have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(2,))
weights = constant_op.constant([1.5, 0.2], shape=(2, 1, 1))
@@ -57,7 +57,7 @@
def testInvalidLabelsDtype(self):
"""An error is raised when labels have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], dtype=dtypes.float32)
with self.assertRaises(ValueError):
@@ -65,7 +65,7 @@
def testNoneWeightRaisesValueError(self):
"""An error is raised when weights are None."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0])
with self.assertRaises(ValueError):
@@ -73,7 +73,7 @@
def testInconsistentLabelsAndWeightsShapesSameRank(self):
"""Error raised when weights and labels have same ranks, different sizes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1, 4.1], shape=(3, 1))
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
weights = constant_op.constant([1.1, 2.0], shape=(2, 1))
@@ -82,7 +82,7 @@
def testInconsistentLabelsAndWeightsShapesDifferentRank(self):
"""Error raised when weights and labels have different ranks and sizes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(2, 1))
weights = constant_op.constant([1.1, 2.0, 2.8], shape=(3,))
@@ -91,7 +91,7 @@
def testOutOfRangeLabels(self):
"""An error is raised when labels are not in [0, num_classes)."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
[0.5, 1.8, -1.0]])
labels = constant_op.constant([1, 0, 4])
@@ -101,7 +101,7 @@
def testZeroLossInt32Labels(self):
"""Loss is 0 if true class logits sufficiently higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
[0.5, 1.8, -1.0]])
labels = constant_op.constant([0, 2, 1], dtype=dtypes.int32)
@@ -110,7 +110,7 @@
def testZeroLossInt64Labels(self):
"""Loss is 0 if true class logits sufficiently higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0],
[-0.5, 0.8, -1.0]])
labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64)
@@ -130,7 +130,7 @@
]
for batch_size, num_classes in logits_shapes:
- with self.test_session():
+ with self.cached_session():
logits = array_ops.placeholder(
dtypes.float32, shape=(batch_size, num_classes))
labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,))
@@ -140,7 +140,7 @@
def testCorrectPredictionsSomeClassesInsideMargin(self):
"""Loss is > 0 even if true class logits are higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0],
[1.5, 1.8, -1.0]])
labels = constant_op.constant([0, 2, 1])
@@ -150,7 +150,7 @@
def testIncorrectPredictions(self):
"""Loss is >0 when an incorrect class has higher logits than true class."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0],
[0.5, -1.8, 2.0]])
labels = constant_op.constant([1, 0, 2])
@@ -162,7 +162,7 @@
def testIncorrectPredictionsColumnLabels(self):
"""Same as above but labels is a rank-2 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -174,7 +174,7 @@
def testIncorrectPredictionsZeroWeights(self):
"""Loss is 0 when all weights are missing even if predictions are wrong."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -185,7 +185,7 @@
def testNonZeroLossWithPythonScalarWeights(self):
"""Weighted loss is correctly computed when weights is a python scalar."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -195,7 +195,7 @@
def testNonZeroLossWithScalarTensorWeights(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -205,7 +205,7 @@
def testNonZeroLossWith1DTensorWeightsColumnLabels(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -216,7 +216,7 @@
def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0], [1.6, 1.8, -4.0]])
labels = constant_op.constant([1, 0, 2, 1])
diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
index 2ff4d41..bad0a59 100644
--- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
+++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
@@ -58,7 +58,7 @@
def testInvalidInputShape(self):
x = constant_op.constant([[2.0, 1.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 10)
with self.assertRaisesWithPredicateMatch(
dense_kernel_mapper.InvalidShapeError,
@@ -70,7 +70,7 @@
x2 = constant_op.constant([[1.0, -1.0, 2.0], [-1.0, 10.0, 1.0],
[4.0, -2.0, -1.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 10, 1.0)
mapped_x1 = rffm.map(x1)
mapped_x2 = rffm.map(x2)
@@ -80,7 +80,7 @@
def testSameOmegaReused(self):
x = constant_op.constant([[2.0, 1.0, 0.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 100)
mapped_x = rffm.map(x)
mapped_x_copy = rffm.map(x)
@@ -93,7 +93,7 @@
y = constant_op.constant([[1.0, -1.0, 2.0]])
stddev = 3.0
- with self.test_session():
+ with self.cached_session():
# The mapped dimension is fairly small, so the kernel approximation is
# very rough.
rffm1 = RandomFourierFeatureMapper(3, 100, stddev)
@@ -113,7 +113,7 @@
y = constant_op.constant([[1.0, -1.0, 2.0]])
stddev = 3.0
- with self.test_session():
+ with self.cached_session():
# The mapped dimension is fairly small, so the kernel approximation is
# very rough.
rffm = RandomFourierFeatureMapper(3, 100, stddev, seed=0)
@@ -139,7 +139,7 @@
normalized_points = [nn.l2_normalize(point, dim=1) for point in points]
total_absolute_error = 0.0
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(input_dim, mapped_dim, stddev, seed=0)
# Cache mappings so that they are not computed multiple times.
cached_mappings = dict((point, rffm.map(point))
diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
index 7289b45..bf89922 100644
--- a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
+++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
@@ -64,7 +64,7 @@
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from shard 0 of stream 1.
sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1})
for i in range(10):
@@ -108,7 +108,7 @@
get_next = iterator.get_next()
data = list()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from shard 0 of stream 2.
sess.run(
init_op, feed_dict={
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index 2f33a2b..0e5ea6b 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -47,7 +47,7 @@
class Seq2SeqTest(test.TestCase):
def testRNNDecoder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -65,7 +65,7 @@
self.assertEqual((2, 2), res[0].shape)
def testBasicRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -81,7 +81,7 @@
self.assertEqual((2, 2), res[0].shape)
def testTiedRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -98,7 +98,7 @@
self.assertEqual((2, 2), res[0].shape)
def testEmbeddingRNNDecoder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -124,7 +124,7 @@
self.assertEqual((2, 2), res[0].h.shape)
def testEmbeddingRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
enc_inp = [
@@ -228,7 +228,7 @@
self.assertAllClose(res1, res3)
def testEmbeddingTiedRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
enc_inp = [
@@ -316,7 +316,7 @@
self.assertAllClose(res1, res3)
def testAttentionDecoder1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -341,7 +341,7 @@
self.assertEqual((2, 2), res[0].shape)
def testAttentionDecoder2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -367,7 +367,7 @@
self.assertEqual((2, 2), res[0].shape)
def testDynamicAttentionDecoder1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -391,7 +391,7 @@
self.assertEqual((2, 2), res[0].shape)
def testDynamicAttentionDecoder2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -416,7 +416,7 @@
self.assertEqual((2, 2), res[0].shape)
def testAttentionDecoderStateIsTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
single_cell = lambda: rnn_cell.BasicLSTMCell( # pylint: disable=g-long-lambda
@@ -448,7 +448,7 @@
self.assertEqual((2, 2), res[0][1].h.shape)
def testDynamicAttentionDecoderStateIsTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda
@@ -479,7 +479,7 @@
self.assertEqual((2, 2), res[0][1].h.shape)
def testEmbeddingAttentionDecoder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -513,7 +513,7 @@
self.assertEqual((2, 2), res[0].shape)
def testEmbeddingAttentionSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
enc_inp = [
@@ -622,7 +622,7 @@
# self.assertAllClose(res1, res3)
def testOne2ManyRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
enc_inp = [
@@ -712,7 +712,7 @@
self.assertAllClose(res1, res3)
def testSequenceLoss(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logits = [constant_op.constant(i + 0.5, shape=[2, 5]) for i in range(3)]
targets = [
constant_op.constant(
@@ -748,7 +748,7 @@
self.assertAllClose(9.656628, res)
def testSequenceLossByExample(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_classes = 5
logits = [
constant_op.constant(
@@ -778,7 +778,7 @@
# classes = 10
# buckets = [(4, 4), (8, 8)]
- # with self.test_session():
+ # with self.cached_session():
# # Here comes a sample Seq2Seq model using GRU cells.
# def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss):
# """Example sequence-to-sequence model that uses GRU cells."""
@@ -839,7 +839,7 @@
random.seed(111)
np.random.seed(111)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# We use sampled softmax so we keep output projection separate.
w = variable_scope.get_variable("proj_w", [24, classes])
w_t = array_ops.transpose(w)
diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md
index a676b70..a4b3d83 100644
--- a/tensorflow/contrib/lite/README.md
+++ b/tensorflow/contrib/lite/README.md
@@ -4,5 +4,5 @@
devices. It enables low-latency inference of on-device machine learning models
with a small binary size and fast performance supporting hardware acceleration.
-See the documentation: https://www.tensorflow.org/mobile/tflite/
-Documentation edits can be made here: [tensorflow/docs_src/mobile/tflite](../../docs_src/mobile/tflite)
+See the documentation: https://www.tensorflow.org/lite/
+Documentation edits can be made here: [tensorflow/contrib/lite/g3doc](./g3doc/)
diff --git a/tensorflow/contrib/lite/examples/android/app/README.md b/tensorflow/contrib/lite/examples/android/app/README.md
index cbdeeac..dc31171 100644
--- a/tensorflow/contrib/lite/examples/android/app/README.md
+++ b/tensorflow/contrib/lite/examples/android/app/README.md
@@ -2,7 +2,7 @@
## Building from Source with Bazel
-1. Install [Bazel](https://docs.bazel.build/versions/master/install.html), the Android NDK and SDK. The recommended versions are specified on this [webpage](https://www.tensorflow.org/mobile/tflite/demo_android#build_tensorflow_lite_and_the_demo_app_from_source).
+1. Install [Bazel](https://docs.bazel.build/versions/master/install.html), the Android NDK and SDK. The recommended versions are specified on this [webpage](https://www.tensorflow.org/lite/demo_android).
2. Build this demo app with Bazel. The demo needs C++11. We configure the fat_apk_cpu flag to package support for 4 hardware variants. You may replace it with --config=android_arm64 on a 64-bit device and --config=android_arm for 32-bit device:
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD
index 835fc25..52e7161 100644
--- a/tensorflow/contrib/lite/experimental/c/BUILD
+++ b/tensorflow/contrib/lite/experimental/c/BUILD
@@ -1,5 +1,12 @@
package(default_visibility = ["//visibility:private"])
+package_group(
+ name = "experimental",
+ packages = [
+ "//tensorflow/contrib/lite/experimental/...",
+ ],
+)
+
licenses(["notice"]) # Apache 2.0
load(
@@ -51,6 +58,9 @@
srcs = ["c_api.cc"],
hdrs = ["c_api.h"],
copts = tflite_copts(),
+ visibility = [
+ ":experimental",
+ ],
deps = [
":c_api_internal",
"//tensorflow/contrib/lite:context",
diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/contrib/lite/g3doc/_book.yaml
index 1dffe30..beaa5c4 100644
--- a/tensorflow/contrib/lite/g3doc/_book.yaml
+++ b/tensorflow/contrib/lite/g3doc/_book.yaml
@@ -5,7 +5,7 @@
# Dropdown menu
- name: Ecosystem
path: /ecosystem
- is_default: True
+ is_default: true
menu:
- include: /ecosystem/_menu_toc.yaml
lower_tabs:
@@ -14,46 +14,49 @@
- name: Guide
contents:
- title: Overview
- path: /mobile/overview
- - title: Developer Guide
- path: /mobile/devguide
- - title: Android Demo App
- path: /mobile/demo_android
- - title: iOS Demo App
- path: /mobile/demo_ios
+ path: /lite/overview
+ - title: Developer guide
+ path: /lite/devguide
+ - title: Android demo app
+ path: /lite/demo_android
+ - title: iOS demo app
+ path: /lite/demo_ios
- title: Performance
- path: /mobile/performance
- - break: True
+ path: /lite/performance
+ - break: true
- title: TensorFlow Lite APIs
- path: /mobile/apis
+ path: /lite/apis
- title: Custom operators
- path: /mobile/custom_operators
- - title: TensorFlow Lite Ops Versioning
- path: /mobile/ops_versioning
- - title: TensorFlow Lite Compatibility Guide
- path: /mobile/tf_ops_compatibility
- - title: List of Hosted Models
- path: /mobile/models
+ path: /lite/custom_operators
+ - title: TensorFlow Lite ops versioning
+ path: /lite/ops_versioning
+ - title: TensorFlow Lite compatibility guide
+ path: /lite/tf_ops_compatibility
+ - title: List of hosted models
+ path: /lite/models
- title: TensorFlow Lite for iOS
- path: /mobile/ios
+ path: /lite/ios
- title: TensorFlow Lite for Raspberry Pi
- path: /mobile/rpi
+ path: /lite/rpi
- - heading: TF Mobile
+ - title: TF Mobile
+ style: accordion
status: deprecated
- - title: Overview
- path: /mobile/tfmobile/
- - title: Building TensorFlow on Android
- path: /mobile/tfmobile/android_build
- - title: Building TensorFlow on IOS
- path: /mobile/tfmobile/ios_build
- - title: Integrating TensorFlow libraries
- path: /mobile/tfmobile/linking_libs
- - title: Preparing models for mobile deployment
- path: /mobile/tfmobile/prepare_models
- - title: Optimizing for mobile
- path: /mobile/tfmobile/optimizing
+ section:
+ - title: Overview
+ path: /lite/tfmobile/
+ - title: Building TensorFlow on Android
+ path: /lite/tfmobile/android_build
+ - title: Building TensorFlow on iOS
+ path: /lite/tfmobile/ios_build
+ - title: Integrating TensorFlow libraries
+ path: /lite/tfmobile/linking_libs
+ - title: Preparing models for mobile deployment
+ path: /lite/tfmobile/prepare_models
+ - title: Optimizing for mobile
+ path: /lite/tfmobile/optimizing
- name: API
contents:
- - include: /mobile/api_docs/python/_toc.yaml
+ - title: API
+ path: /api_docs/python/tf/contrib/lite
diff --git a/tensorflow/contrib/lite/g3doc/_index.yaml b/tensorflow/contrib/lite/g3doc/_index.yaml
index b3f21e2..bc66cc5 100644
--- a/tensorflow/contrib/lite/g3doc/_index.yaml
+++ b/tensorflow/contrib/lite/g3doc/_index.yaml
@@ -1,60 +1,209 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
+project_path: /lite/_project.yaml
+book_path: /lite/_book.yaml
description: <!--no description-->
landing_page:
+ custom_css_path: /site-assets/css/style.css
rows:
- - heading: TensorFlow Lite is a lightweight solution for mobile and embedded devices.
+ - heading: TensorFlow Lite is for mobile and embedded devices.
+ description: >
+ <p style="max-width: 75%;">
+ TensorFlow Lite is the official solution for running machine learning
+ models on mobile and embedded devices. It enables on‑device machine
+ learning inference with low latency and a small binary size on Android,
+ iOS, and other operating systems.
+ </p>
+ <style>
+ .tfo-landing-row-heading {
+ padding-top: 0 !important;
+ }
+ .tfo-landing-row-heading h2 {
+ margin-top: 0 !important;
+ }
+ .tfo-landing-row-heading-list ol, .tfo-landing-row-heading-list ul {
+ margin-top: 0;
+ }
+ </style>
+
+ - classname: tfo-landing-row-heading tfo-landing-row-heading-list
+ heading: Many benefits
+ description: >
+ On-device ML inference is difficult because of the many constraints—TensorFlow Lite can solve these:
items:
- - classname: devsite-landing-row-50
+ - list:
+ - heading: Performance
+ description: >
+ TF Lite is fast with no noticeable accuracy loss—see the <a href="./performance">metrics</a>.
+ icon:
+ icon_name: lens
+ foreground: theme
+ - heading: Portability
+ description: >
+ <a href="https://developer.android.com/ndk/guides/neuralnetworks/" class="external">Android</a>,
+ iOS, and more specialized IoT devices.
+ icon:
+ icon_name: lens
+ foreground: theme
+ - list:
+ - heading: Low latency
+ description: >
+ Optimized float- and fixed-point CPU kernels, op‑fusing, and more.
+ icon:
+ icon_name: lens
+ foreground: theme
+ - heading: Acceleration
+ description: >
+ Integration with GPU and internal/external accelerators.
+ icon:
+ icon_name: lens
+ foreground: theme
+ - list:
+ - heading: Small model size
+ description: >
+ Controlled dependencies, <a href="https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3" class="external">quantization</a>,
+ and op registration.
+ icon:
+ icon_name: lens
+ foreground: theme
+ - heading: Tooling
+ description: >
+ Conversion, compression, benchmarking, power-consumption, and more.
+ icon:
+ icon_name: lens
+ foreground: theme
+
+ - classname: devsite-landing-row-logos tfo-landing-row-heading
+ heading: Companies using TensorFlow Lite
+ items:
+ - custom_image:
+ path: ./images/landing-page/photos_logo.png
+ path: https://www.photos.google.com
+ - custom_image:
+ path: ./images/landing-page/gboard_logo.png
+ path: https://play.google.com/store/apps/details?id=com.google.android.inputmethod.latin&hl=en_US
+ - custom_image:
+ path: ./images/landing-page/gmail_logo.png
+ path: https://www.google.com/gmail/
+ - custom_image:
+ path: ./images/landing-page/assistant_logo.png
+ path: https://assistant.google.com/
+
+ - classname: devsite-landing-row-logos
+ items:
+ - custom_image:
+ path: ./images/landing-page/vsco_logo.png
+ path: https://vsco.co
+ - custom_image:
+ path: ./images/landing-page/shazam_logo.png
+ path: https://www.shazam.com/
+ - custom_image:
+ path: ./images/landing-page/nest_logo.png
+ path: https://nest.com/
+ - custom_image:
+ path: ./images/landing-page/loseit_logo.png
+ path: https://www.loseit.com/
+
+ - classname: devsite-landing-row-no-image-background devsite-landing-row-67
+ background: grey
+ items:
+ - description: >
+ <em>“TensorFlow Lite helped us introduce machine learning and AI into our
+ app in an easy and streamlined way. We could reduce the size of our
+ models while keeping the accuracy high. This helped us create an amazing
+ fishing experience for our users by allowing them to identify any fish
+ species with just a photo.”</em>
+ image_path: ./images/landing-page/fishbrain_logo_big.png
+
+ - heading: How it works
+ items:
+ - heading: Build
+ icon:
+ icon_name: build
description: >
- TensorFlow Lite is TensorFlow’s lightweight solution for mobile and
- embedded devices. It enables on-device machine learning inference with
- low latency and a small binary size. TensorFlow Lite also supports
- hardware acceleration with the
- <a href='https://developer.android.com/ndk/guides/neuralnetworks/index.html'>Android Neural Networks API</a>.
- list:
- - heading: Key point 1
- description: >
- [high-level overview]
- icon:
- icon_name: chevron_right
- foreground: theme
- background: grey
- - heading: Key point 2
- description: >
- [high-level overview]
- icon:
- icon_name: chevron_right
- foreground: theme
- background: grey
- - heading: Key point 3
- description: >
- [high-level overview]
- icon:
- icon_name: chevron_right
- foreground: theme
- background: grey
- code_block: |
- <pre class = "prettyprint">
- $ toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TFLITE \
- --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
- --inference_type=FLOAT \
- --input_type=FLOAT \
- --input_arrays=input \
- --output_arrays=MobilenetV1/Predictions/Reshape_1 \
- --input_shapes=1,224,224,3
- </pre>
+ Build a new model or retrain an existing one, such as using transfer learning.
+ buttons:
+ - label: Read the developer guide
+ path: /lite/devguide
+ classname: button button-primary tfo-button-primary
+ - heading: Convert
+ icon:
+ icon_name: autorenew
+ description: >
+ Convert a TensorFlow model into a compressed flat buffer with the
+ TensorFlow Lite Optimizing Converter (TOCO).
+ buttons:
+ - label: Read the TOCO guide
+ path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/python_api.md
+ classname: button button-primary tfo-button-primary
+ - heading: Deploy
+ icon:
+ icon_name: bolt
+ description: >
+ Take the compressed <code>.tflite</code> file and load it into a mobile
+ or embedded device.<br/>
+ See the <a href="#build-your-first-tensorflow-lite-app">tutorials below</a> to build an app.
+
+ - heading: Build your first TensorFlow Lite app
+ background: grey
+ items:
+ - classname: tfo-landing-row-item-inset-white
+ heading: Get started
+ description: >
+ <ul>
+ <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/" class="external">TensorFlow for Poets</a></li>
+ <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-tflite/" class="external">TensorFlow for Poets 2: Android</a></li>
+ <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-ios/" class="external">TensorFlow for Poets 2: iOS</a></li>
+ <li>Intermediate: <a href="https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193" class="external">Object detection tutorial</a></li>
+ </ul>
+ - classname: tfo-landing-row-item-inset-white
+ heading: Share your TensorFlow Lite story
+ description: >
+ We love to hear what you're working on—it may even get highlighted on
+ our social media! <a href="https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss" class="external">Tell us</a>.
+
+ - classname: devsite-landing-row-no-image-background devsite-landing-row-67
+ items:
+ - description: >
+ <p>
+ <em>“The release of TensorFlow Lite has allowed us to deploy an engaging
+ real-time experience to our users that eliminates the requirement
+ for a data connection. TensorFlow Lite’s ability to compress and
+ optimize the TensorFlow graph for mobile deployment has been
+ transformative in expanding the capabilities of Snap It.</em>
+ </p>
+ <p>
+ <em>Through TensorFlow Lite, our users can now enjoy a state of the
+ art, computer-vision-based food logging experience without worrying
+ about signal strength. We look forward to future collaborations
+ with the TensorFlow Lite team.”</em>
+ </p>
+ image_path: ./images/landing-page/loseit_logo_big.png
- classname: devsite-landing-row-cards
+ background: grey
+ heading: Updates
items:
+ - heading: Introducing the Model Optimization Toolkit
+ image_path: /ecosystem/images/tf-logo-card-16x9.png
+ path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3
+ buttons:
+ - label: Read on TensorFlow blog
+ path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3
+ - heading: East Africa Cassava App
+ image_path: ./images/landing-page/detect_crop_disease_in_africa.png
+ path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5
+ buttons:
+ - label: Read more
+ path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5
- heading: Using TensorFlow Lite on Android
image_path: /ecosystem/images/tf-logo-card-16x9.png
path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d
buttons:
- label: Read on TensorFlow blog
path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d
+
+ - classname: devsite-landing-row-cards
+ background: grey
+ items:
- heading: TensorFlow Lite at the Dev Summit
youtube_id: FAMfy7izB6A
buttons:
@@ -66,3 +215,4 @@
buttons:
- label: View on GitHub
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite
+ - classname: devsite-landing-row-item-hidden
diff --git a/tensorflow/contrib/lite/g3doc/_project.yaml b/tensorflow/contrib/lite/g3doc/_project.yaml
index b396665..3ce6986 100644
--- a/tensorflow/contrib/lite/g3doc/_project.yaml
+++ b/tensorflow/contrib/lite/g3doc/_project.yaml
@@ -1,10 +1,10 @@
name: TensorFlow Lite
-breadcrumb_name: Mobile
-home_url: /mobile/
+breadcrumb_name: TensorFlow Lite
+home_url: /lite/
parent_project_metadata_path: /_project.yaml
description: >
TensorFlow Lite is a lightweight solution for mobile and embedded devices.
-use_site_branding: True
-hide_from_products_list: True
+use_site_branding: true
+hide_from_products_list: true
content_license: cc3-apache2
buganizer_id: 316308
diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml b/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml
deleted file mode 100644
index 1e1c44c..0000000
--- a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-# Automatically generated file; please do not edit
-toc:
- - title: TensorFlow Lite
- section:
- - title: Overview
- path: /mobile/api_docs/python/
diff --git a/tensorflow/contrib/lite/g3doc/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md
index 90e7915..0eed516 100644
--- a/tensorflow/contrib/lite/g3doc/devguide.md
+++ b/tensorflow/contrib/lite/g3doc/devguide.md
@@ -1,5 +1,4 @@
-
-# Developer Guide
+# TF Lite Developer Guide
Using a TensorFlow Lite model in your mobile app requires multiple
considerations: you must choose a pre-trained or custom model, convert the model
@@ -55,7 +54,7 @@
### Train a custom model
A developer may choose to train a custom model using Tensorflow (see the
-[TensorFlow tutorials](../../tutorials/) for examples of building and training
+[TensorFlow tutorials](../tutorials/) for examples of building and training
models). If you have already written a model, the first step is to export this
to a `tf.GraphDef` file. This is required because some formats do not store the
model structure outside the code, and we must communicate with other parts of the
@@ -205,7 +204,7 @@
[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app).
You can also download a
[prebuilt APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
-See the <a href="../demo_android.md">Android demo</a> guide for details.
+See the <a href="./demo_android.md">Android demo</a> guide for details.
The <a href="./android_build.md">Android mobile</a> guide has instructions for
installing TensorFlow on Android and setting up `bazel` and Android Studio.
@@ -214,7 +213,7 @@
To integrate a TensorFlow model in an iOS app, see the
[TensorFlow Lite for iOS](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md)
-guide and <a href="../demo_ios.md">iOS demo</a> guide.
+guide and <a href="./demo_ios.md">iOS demo</a> guide.
#### Core ML support
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png
new file mode 100644
index 0000000..ced0872
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png b/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png
new file mode 100644
index 0000000..45b3b4f
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png
new file mode 100644
index 0000000..bc1bf6e
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png
new file mode 100644
index 0000000..d76fca8
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png
new file mode 100644
index 0000000..f1a93ab
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png
new file mode 100644
index 0000000..21aa2c8
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png
new file mode 100644
index 0000000..b6b3d14
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png
new file mode 100644
index 0000000..b3e46d4
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png
new file mode 100644
index 0000000..35bfd97
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png
new file mode 100644
index 0000000..4333426
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png
new file mode 100644
index 0000000..6ec412c
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png
new file mode 100644
index 0000000..f408f90
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
index c7cdee0..b0f32a8 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
@@ -93,7 +93,7 @@
guide you through the basics here.
- First, follow our instructions for
- <a href="http://www.tensorflow.org/install/install_sources">installing from sources</a>.
+ <a href="http://www.tensorflow.org/install/source">installing from sources</a>.
This will also guide you through installing Bazel and cloning the
TensorFlow code.
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
index d003bb2..49ad35d 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
@@ -4,7 +4,7 @@
TensorFlow was designed to be a good deep learning solution for mobile
platforms. Currently we have two solutions for deploying machine learning
applications on mobile and embedded devices: TensorFlow for Mobile and
-<a href="../index.md">TensorFlow Lite</a>.
+<a href="../../lite">TensorFlow Lite</a>.
## TensorFlow Lite versus TensorFlow Mobile
diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md
index 2634934..df77bfa 100644
--- a/tensorflow/contrib/lite/java/ovic/README.md
+++ b/tensorflow/contrib/lite/java/ovic/README.md
@@ -4,7 +4,7 @@
## Pre-requisite
-Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK.
+Follow the steps [here](https://www.tensorflow.org/lite/demo_android) to install TensorFlow, Bazel, and the Android NDK and SDK.
## Test the benchmarker:
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index a6fd4ac..195474e 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -43,6 +43,7 @@
"compatibility.h",
"types.h",
],
+ deps = ["@com_google_absl//absl/base:core_headers"],
)
config_setting(
@@ -458,7 +459,7 @@
],
copts = NEON_FLAGS_IF_APPLICABLE,
deps = [
- "//tensorflow/contrib/lite/kernels:activation_functor",
+ "@com_google_absl//absl/base:core_headers",
"//tensorflow/contrib/lite/c:c_api_internal",
"@arm_neon_2_x86_sse",
"@gemmlowp",
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index ee3fe78..f892b8f 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -24,6 +24,9 @@
namespace tflite {
namespace optimized_ops {
+// TODO(b/80418076): Move to legacy ops file, along with invocations.
+static constexpr int kDepthwiseReverseShift = -1;
+
// Implementation of quantized DepthwiseConv
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
@@ -1712,8 +1715,8 @@
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
#ifdef USE_NEON
- const bool shift_left = (output_shift <= 0);
- const int32 multiplier_power_of_two = shift_left ? (1 << -output_shift) : 1;
+ const bool shift_left = (output_shift > 0);
+ const int32 multiplier_power_of_two = shift_left ? (1 << output_shift) : 1;
#endif
TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
@@ -1872,7 +1875,7 @@
acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
}
for (int j = 0; j < 4; j++) {
- acc[j] = RoundingDivideByPOT(acc[j], output_shift);
+ acc[j] = RoundingDivideByPOT(acc[j], -output_shift);
}
} else {
// Fixed-point multiplication.
@@ -1916,8 +1919,8 @@
acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
// Rounding right shift.
- acc0 = RoundingDivideByPOT(acc0, output_shift);
- acc1 = RoundingDivideByPOT(acc1, output_shift);
+ acc0 = RoundingDivideByPOT(acc0, -output_shift);
+ acc1 = RoundingDivideByPOT(acc1, -output_shift);
} else {
// Fixed-point multiplication.
acc0 = vmulq_n_s32(acc0, multiplier_power_of_two);
@@ -1953,7 +1956,7 @@
// Fixed-point multiplication.
acc = vqrdmulhq_n_s32(acc, output_multiplier);
// Rounding right shift.
- acc = RoundingDivideByPOT(acc, output_shift);
+ acc = RoundingDivideByPOT(acc, -output_shift);
} else {
// Fixed-point multiplication.
acc = vmulq_n_s32(acc, multiplier_power_of_two);
@@ -1980,7 +1983,7 @@
for (; i < num_output_values; i++) {
int32 acc = acc_buffer[i];
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
@@ -2020,7 +2023,8 @@
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kDepthwiseReverseShift * output_shift;
DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
index e14d04a..4809ddd 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
@@ -49,7 +49,7 @@
int32 output_multiplier;
int32 output_activation_min;
int32 output_activation_max;
- int32 output_shift;
+ int32 output_right_shift;
int32 input_width;
int32 input_height;
int32 stride_width;
@@ -75,7 +75,7 @@
#define OFFSET_OUTPUT_MULTIPLIER 52
#define OFFSET_OUTPUT_ACTIVATION_MIN 56
#define OFFSET_OUTPUT_ACTIVATION_MAX 60
-#define OFFSET_OUTPUT_SHIFT 64
+#define OFFSET_OUTPUT_RIGHT_SHIFT 64
#define OFFSET_INPUT_WIDTH 68
#define OFFSET_INPUT_HEIGHT 72
#define OFFSET_STRIDE_WIDTH 76
@@ -105,8 +105,8 @@
OFFSET_OUTPUT_ACTIVATION_MIN, "");
static_assert(offsetof(DepthwiseConvParams, output_activation_max) ==
OFFSET_OUTPUT_ACTIVATION_MAX, "");
-static_assert(offsetof(DepthwiseConvParams, output_shift) ==
- OFFSET_OUTPUT_SHIFT, "");
+static_assert(offsetof(DepthwiseConvParams, output_right_shift) ==
+ OFFSET_OUTPUT_RIGHT_SHIFT, "");
static_assert(offsetof(DepthwiseConvParams, input_width) ==
OFFSET_INPUT_WIDTH, "");
static_assert(offsetof(DepthwiseConvParams, input_height) ==
@@ -189,7 +189,7 @@
"ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
"ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
"dup v27.4s, w9\n"
- "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v29.4s, w2\n"
"ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"dup v30.4s, w4\n"
@@ -1166,7 +1166,7 @@
// values from time to time when there are not enough NEON registers.
// We use x9--x15 general purpose registers as they are caller-saved
// temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT
- "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"ldr w0, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
"cmp %w[output_window_height], #2\n"
"dup v28.8h, w0\n"
@@ -2216,7 +2216,7 @@
"dup v27.4s, w10\n"
"ld1 {v0.8b}, [%[filter_ptr]], #8\n"
"cmp x11, #16\n"
- "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v28.4s, w9\n"
"ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"neg w10, w10\n"
@@ -2355,7 +2355,7 @@
"dup v26.8h, w6\n"
"ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
"dup v27.4s, w7\n"
- "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v28.4s, w6\n"
"ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"neg w7, w7\n"
@@ -2532,7 +2532,7 @@
"dup v26.8h, w12\n"
"ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
"dup v27.4s, w13\n"
- "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v28.4s, w12\n"
"ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"neg w13, w13\n"
@@ -2739,7 +2739,7 @@
"dup v26.8h, w12\n"
"ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
"dup v27.4s, w13\n"
- "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v28.4s, w12\n"
"ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"neg w13, w13\n"
@@ -2910,7 +2910,7 @@
#undef OFFSET_OUTPUT_MULTIPLIER
#undef OFFSET_OUTPUT_ACTIVATION_MIN
#undef OFFSET_OUTPUT_ACTIVATION_MAX
-#undef OFFSET_OUTPUT_SHIFT
+#undef OFFSET_OUTPUT_RIGHT_SHIFT
#undef OFFSET_INPUT_WIDTH
#undef OFFSET_INPUT_HEIGHT
#undef OFFSET_OUTPUT_WIDTH
@@ -3194,7 +3194,7 @@
(stride_height == 1 || stride_height == 2) &&
(stride_width == stride_height) && (pad_width == 0 || pad_width == 1) &&
(pad_height == 0 || pad_height == 1) && (pad_width == pad_height) &&
- (input_depth % 8) == 0 && (output_shift > 0) &&
+ (input_depth % 8) == 0 && (output_shift <= 0) &&
dilation_width_factor == 1 && dilation_height_factor == 1;
if (!supported) {
@@ -3272,7 +3272,7 @@
params.output_offset = output_offset;
params.filter_offset = filter_offset;
params.output_multiplier = output_multiplier;
- params.output_shift = output_shift;
+ params.output_right_shift = -output_shift;
params.output_activation_min = output_activation_min;
params.output_activation_max = output_activation_max;
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 6f4e135..1a2d451 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -88,7 +88,7 @@
// This will be phased out as the shifts are revised with more thought. Use of a
// constant enables us to track progress on this work.
//
-// Used mainly to convert from old-style shifts (right) to new-style (left).
+// Used to convert from old-style shifts (right) to new-style (left).
static constexpr int kReverseShift = -1;
// Make a local VectorMap typedef allowing to map a float array
@@ -977,7 +977,7 @@
const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
output_shape, output_dim_count - 1);
static constexpr int kPeel = 4;
- const bool shift_left = (output_shift <= 0);
+ const bool shift_left = (output_shift > 0);
for (int k = 0; k < input_size; k += 64) {
optimized_ops_preload_l1_stream(input_data + k);
}
@@ -1090,7 +1090,7 @@
bias_ptr += 4;
reduced = vaddq_s32(reduced, bias_vec);
if (shift_left) {
- const int32 multiplier_power_of_two = 1 << -output_shift;
+ const int32 multiplier_power_of_two = 1 << output_shift;
reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
} else {
@@ -1098,7 +1098,7 @@
reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
// Rounding-shift-right.
using gemmlowp::RoundingDivideByPOT;
- reduced = RoundingDivideByPOT(reduced, output_shift);
+ reduced = RoundingDivideByPOT(reduced, -output_shift);
}
// Add the output offset.
const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
@@ -1195,7 +1195,7 @@
gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
output_data, output_rows, batches, output_rows);
const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
- bias_data, output_rows, output_offset, output_multiplier, -output_shift,
+ bias_data, output_rows, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
@@ -1219,7 +1219,8 @@
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
op_params.quantized_activation_min = output_activation_min;
op_params.quantized_activation_max = output_activation_max;
@@ -1274,14 +1275,14 @@
if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
GEMVForLstmCellWithSymmetricRange(
input_shape, input_data, filter_shape, filter_data, bias_shape,
- bias_data_int32, output_multiplier, -output_shift, output_shape,
+ bias_data_int32, output_multiplier, output_shift, output_shape,
output_data);
return;
}
if (!(output_depth % 4) && !(accum_depth % 8)) {
GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
filter_offset, bias_shape, bias_data_int32,
- output_multiplier, -output_shift, output_shape,
+ output_multiplier, output_shift, output_shape,
output_data);
return;
}
@@ -1302,7 +1303,7 @@
scale_stage.result_offset_after_shift = 0;
scale_stage.result_fixedpoint_multiplier = output_multiplier;
// Note that this shift is negated wrt ordinary FC.
- scale_stage.result_exponent = -output_shift;
+ scale_stage.result_exponent = output_shift;
gemmlowp::OutputStageClamp clamp_stage;
clamp_stage.min = output_activation_min;
clamp_stage.max = output_activation_max;
@@ -1330,7 +1331,8 @@
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
op_params.quantized_activation_min = output_activation_min;
op_params.quantized_activation_max = output_activation_max;
@@ -1376,8 +1378,8 @@
#if defined USE_NEON
const int8* shuffled_weights_ptr = shuffled_weights_data;
if (batches == 1) {
- const int right_shift = output_shift > 0 ? output_shift : 0;
- const int left_shift = output_shift > 0 ? 0 : -output_shift;
+ const int right_shift = output_shift > 0 ? 0 : -output_shift;
+ const int left_shift = output_shift > 0 ? output_shift : 0;
for (int c = 0; c < output_depth; c += 4) {
// Accumulation loop.
int32x4_t row_accum0 = vdupq_n_s32(0);
@@ -1443,8 +1445,8 @@
vst1_s16(output_data + c, res16);
}
} else if (batches == 4) {
- const int right_shift = output_shift > 0 ? output_shift : 0;
- const int left_shift = output_shift > 0 ? 0 : -output_shift;
+ const int right_shift = output_shift > 0 ? 0 : -output_shift;
+ const int left_shift = output_shift > 0 ? output_shift : 0;
for (int c = 0; c < output_depth; c += 4) {
const int8* shuffled_input_ptr =
reinterpret_cast<const int8*>(shuffled_input_workspace_data);
@@ -1575,8 +1577,8 @@
// (16-bit, typically 3 integer bits) fixed-point format. The quantized
// multiplier and shift here have been pre-computed offline
// (e.g. by toco).
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ acc =
+ MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
// Saturate, cast to int16, and store to output array.
acc = std::max(acc, -32768);
acc = std::min(acc, 32767);
@@ -1627,7 +1629,7 @@
// quantized multiplier and shift here have been pre-computed offline
// (e.g. by toco).
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ output_shift);
// Saturate, cast to int16, and store to output array.
acc = std::max(acc, -32768);
acc = std::min(acc, 32767);
@@ -1818,7 +1820,8 @@
uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
tflite::FullyConnectedParams op_params;
op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
op_params.quantized_activation_min = output_activation_min;
op_params.quantized_activation_max = output_activation_max;
@@ -2437,7 +2440,7 @@
gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
output_data, output_rows, output_cols);
const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
- bias_data, output_rows, output_offset, output_multiplier, -output_shift,
+ bias_data, output_rows, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
@@ -2471,7 +2474,8 @@
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
op_params.quantized_activation_min = output_activation_min;
op_params.quantized_activation_max = output_activation_max;
@@ -2792,6 +2796,7 @@
*output_inv_sqrt <<= -*output_shift;
*output_shift = 0;
}
+ // Convert right shift (right is positive) to left shift.
*output_shift *= kReverseShift;
}
@@ -3799,11 +3804,11 @@
uint8* concat_temp_data_uint8,
const RuntimeShape& unextended_activ_temp_shape,
int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label(
+ "LstmCell/quantized (8bit external, 16bit internal)");
int32 weights_zero_point = params.weights_zero_point;
int32 accum_multiplier = params.accum_multiplier;
int accum_shift = params.accum_shift;
- gemmlowp::ScopedProfilingLabel label(
- "LstmCell/quantized (8bit external, 16bit internal)");
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
@@ -5018,7 +5023,7 @@
std::max(diff_min - 1, // Note use of > below instead of >= above.
MultiplyByQuantizedMultiplierSmallerThanOneExp(
rescaled_diff_min, reverse_scaling_divisor,
- kReverseShift * reverse_scaling_right_shift));
+ -reverse_scaling_right_shift));
for (int c = 0; c < depth; ++c) {
int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
@@ -5058,8 +5063,7 @@
LogSoftmax(params, input_shape, input_data, output_shape, output_data);
}
-inline void Logistic(const LogisticParams& params,
- const RuntimeShape& input_shape, const float* input_data,
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic");
auto input_map = MapAsVector(input_data, input_shape);
@@ -5068,13 +5072,13 @@
input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- LogisticParams params;
- // No params currently needed by float Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Logistic(input_shape, input_data, output_shape, output_data);
}
inline void Logistic(const LogisticParams& params,
@@ -5310,22 +5314,21 @@
Logistic(params, input_shape, input_data, output_shape, output_data);
}
-inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
- const float* input_data, const RuntimeShape& output_shape,
- float* output_data) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Tanh");
auto input_map = MapAsVector(input_data, input_shape);
auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = input_map.array().tanh();
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- TanhParams params;
- // Currently no params needed for float Tanh.
- Tanh(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Tanh(input_shape, input_data, output_shape, output_data);
}
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
@@ -6380,6 +6383,16 @@
output_map.array() = input1_map.array().min(min_value);
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
@@ -6391,6 +6404,16 @@
output_map.array() = input1_map.array().max(max_value);
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
const RuntimeShape& input_shape, const T* input_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
index 38aea14..ecc655c 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -26,6 +26,9 @@
namespace tflite {
namespace reference_ops {
+// TODO(b/80418076): Move to legacy ops file, along with invocations.
+static constexpr int kDepthwiseReverseShift = -1;
+
inline void DepthwiseConv(
const DepthwiseParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
@@ -95,7 +98,7 @@
acc += bias_data[oc];
}
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
@@ -137,7 +140,8 @@
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kDepthwiseReverseShift * output_shift;
DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 87bcc8c..bb1d30b 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -384,7 +384,7 @@
acc += bias_data[out_channel];
}
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- kReverseShift * output_shift);
+ output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
@@ -422,7 +422,8 @@
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
op_params.quantized_activation_min = output_activation_min;
op_params.quantized_activation_max = output_activation_max;
@@ -714,8 +715,7 @@
if (bias_data) {
acc += bias_data[out_c];
}
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- kReverseShift * output_shift);
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
@@ -740,7 +740,8 @@
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
op_params.quantized_activation_min = output_activation_min;
op_params.quantized_activation_max = output_activation_max;
@@ -793,8 +794,8 @@
// (16-bit, typically 3 integer bits) fixed-point format. The quantized
// multiplier and shift here have been pre-computed offline
// (e.g. by toco).
- accum = MultiplyByQuantizedMultiplier(accum, output_multiplier,
- -output_shift);
+ accum =
+ MultiplyByQuantizedMultiplier(accum, output_multiplier, output_shift);
// Saturate, cast to int16, and store to output array.
accum = std::max(accum, output_activation_min - output_offset);
accum = std::min(accum, output_activation_max - output_offset);
@@ -820,7 +821,8 @@
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
op_params.quantized_activation_min = output_activation_min;
op_params.quantized_activation_max = output_activation_max;
@@ -919,8 +921,8 @@
// (16-bit, typically 3 integer bits) fixed-point format. The quantized
// multiplier and shift here have been pre-computed offline
// (e.g. by toco).
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ acc =
+ MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
// Saturate, cast to int16, and store to output array.
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
@@ -971,7 +973,7 @@
// quantized multiplier and shift here have been pre-computed offline
// (e.g. by toco).
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ output_shift);
// Saturate, cast to int16, and store to output array.
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
@@ -996,7 +998,8 @@
uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
tflite::FullyConnectedParams op_params;
op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
op_params.quantized_activation_min = output_activation_min;
op_params.quantized_activation_max = output_activation_max;
@@ -1154,6 +1157,7 @@
*output_inv_sqrt <<= -*output_shift;
*output_shift = 0;
}
+ // Convert right shift (right is positive) to left shift.
*output_shift *= kReverseShift;
}
@@ -1912,7 +1916,7 @@
const float* input2_data,
const RuntimeShape& output_shape,
float* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/float");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -1953,7 +1957,7 @@
const uint8* input2_data,
const RuntimeShape& output_shape,
uint8* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/uint8");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -2017,7 +2021,7 @@
const int32* input2_data,
const RuntimeShape& output_shape,
int32* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/int32");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -2057,7 +2061,7 @@
const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
const RuntimeShape& output_shape, T* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/templated");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/templated");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -2408,6 +2412,125 @@
output_data, output_dims);
}
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
+ const float* prev_activ_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& unextended_bias_shape,
+ const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
+ const float* prev_state_data,
+ const RuntimeShape& unextended_output_state_shape, float* output_state_data,
+ const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
+ const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
+ const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int batches =
+ MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
+ output_state_shape, 0, output_activ_shape, 0);
+ const int height =
+ MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
+ output_state_shape, 1, output_activ_shape, 1);
+ const int width =
+ MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
+ output_state_shape, 2, output_activ_shape, 2);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
+ const int total_input_depth = prev_activ_depth + input_depth;
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+ const int intern_activ_depth =
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
+ const int output_depth =
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+
+ // Concatenate prev_activ and input data together
+ std::vector<float const*> concat_input_arrays_data;
+ std::vector<RuntimeShape const*> concat_input_arrays_shapes;
+ concat_input_arrays_data.push_back(input_data);
+ concat_input_arrays_data.push_back(prev_activ_data);
+ concat_input_arrays_shapes.push_back(&input_shape);
+ concat_input_arrays_shapes.push_back(&prev_activ_shape);
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = concat_input_arrays_data.size();
+ Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
+ &(concat_input_arrays_data[0]), concat_temp_shape,
+ concat_temp_data);
+
+ // Fully connected
+ tflite::FullyConnectedParams fc_params;
+ fc_params.float_activation_min = std::numeric_limits<float>::lowest();
+ fc_params.float_activation_max = std::numeric_limits<float>::max();
+ FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
+ weights_data, bias_shape, bias_data, activ_temp_shape,
+ activ_temp_data);
+
+ // Memory state update (the LSTM "guts")
+ for (int b = 0; b < batches; ++b) {
+ for (int w = 0; w < width; ++w) {
+ for (int h = 0; h < height; ++h) {
+ for (int c = 0; c < output_depth; ++c) {
+ const float input_gate =
+ 1.f /
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 0 * output_depth + c)]));
+ const float new_input = std::tanh(activ_temp_data[Offset(
+ activ_temp_shape, b, h, w, 1 * output_depth + c)]);
+ const float forget_gate =
+ 1.f /
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 2 * output_depth + c)]));
+ const float output_gate =
+ 1.f /
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 3 * output_depth + c)]));
+ const float new_state =
+ input_gate * new_input +
+ forget_gate *
+ prev_state_data[Offset(prev_state_shape, b, h, w, c)];
+ output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
+ output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
+ output_gate * std::tanh(new_state);
+ }
+ }
+ }
+ }
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
const float* prev_activ_data,
const Dims<4>& prev_activ_dims, const float* weights_data,
@@ -2418,77 +2541,17 @@
const Dims<4>& output_activ_dims, float* concat_temp_data,
const Dims<4>& concat_temp_dims, float* activ_temp_data,
const Dims<4>& activ_temp_dims) {
- const int batches =
- MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3,
- output_state_dims, 3, output_activ_dims, 3);
- const int height =
- MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2,
- output_state_dims, 2, output_activ_dims, 2);
- const int width =
- MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1,
- output_state_dims, 1, output_activ_dims, 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
- const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
- const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
- const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
- // Concatenate prev_activ and input data together
- std::vector<float const*> concat_input_arrays_data;
- std::vector<Dims<4> const*> concat_input_arrays_dims;
- concat_input_arrays_data.push_back(input_data);
- concat_input_arrays_data.push_back(prev_activ_data);
- concat_input_arrays_dims.push_back(&input_dims);
- concat_input_arrays_dims.push_back(&prev_activ_dims);
- Concatenation<FusedActivationFunctionType::kNone, float>(
- 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
- concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
-
- // Fully connected
- FullyConnected<FusedActivationFunctionType::kNone>(
- concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
- bias_dims, activ_temp_data, activ_temp_dims);
-
- // Memory state update (the LSTM "guts")
- for (int b = 0; b < batches; ++b) {
- for (int w = 0; w < width; ++w) {
- for (int h = 0; h < height; ++h) {
- for (int c = 0; c < output_depth; ++c) {
- const float input_gate =
- 1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 0 * output_depth + c, w, h, b)]));
- const float new_input = std::tanh(activ_temp_data[Offset(
- activ_temp_dims, 1 * output_depth + c, w, h, b)]);
- const float forget_gate =
- 1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 2 * output_depth + c, w, h, b)]));
- const float output_gate =
- 1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 3 * output_depth + c, w, h, b)]));
- const float new_state =
- input_gate * new_input +
- forget_gate *
- prev_state_data[Offset(prev_state_dims, c, w, h, b)];
- output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state;
- output_activ_data[Offset(output_activ_dims, c, w, h, b)] =
- output_gate * std::tanh(new_state);
- }
- }
- }
- }
+ LstmCell(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(prev_activ_dims), prev_activ_data,
+ DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(prev_state_dims), prev_state_data,
+ DimsToShape(output_state_dims), output_state_data,
+ DimsToShape(output_activ_dims), output_activ_data,
+ DimsToShape(concat_temp_dims), concat_temp_data,
+ DimsToShape(activ_temp_dims), activ_temp_data);
}
// Quantized LSTM cell implementation.
@@ -2576,52 +2639,90 @@
// aiming for 16-bit fixed-point quantization of these internal nodes here.
//
template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
- const uint8* prev_activ_data_uint8,
- const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
- const Dims<4>& weights_dims, const int32* bias_data_int32,
- const Dims<4>& bias_dims, const int16* prev_state_data_int16,
- const Dims<4>& prev_state_dims, int16* output_state_data_int16,
- const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
- const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
- const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
- const Dims<4>& activ_temp_dims, int32 weights_zero_point,
- int32 accum_multiplier, int accum_shift,
- gemmlowp::GemmContext* gemm_context) {
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const uint8* input_data_uint8,
+ const RuntimeShape& unextended_prev_activ_shape,
+ const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
+ const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
+ const int32* bias_data_int32,
+ const RuntimeShape& unextended_prev_state_shape,
+ const int16* prev_state_data_int16,
+ const RuntimeShape& unextended_output_state_shape,
+ int16* output_state_data_int16,
+ const RuntimeShape& unextended_output_activ_shape,
+ uint8* output_activ_data_uint8,
+ const RuntimeShape& unextended_concat_temp_shape,
+ uint8* concat_temp_data_uint8,
+ const RuntimeShape& unextended_activ_temp_shape,
+ int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
(void)gemm_context; // only used in optimized code.
+ int32 weights_zero_point = params.weights_zero_point;
+ int32 accum_multiplier = params.accum_multiplier;
+ int accum_shift = params.accum_shift;
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
// Gather dimensions information, and perform consistency checks.
- const int outer_size =
- MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prev_state_dims,
- output_state_dims, output_activ_dims);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int outer_size = MatchingFlatSizeSkipDim(
+ input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
+ output_activ_shape);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
- const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+ const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
const int fc_output_depth =
- MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
- const int fc_accum_depth = ArraySize(weights_dims, 0);
- TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);
+ MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
+ const int fc_accum_depth = total_input_depth;
+ TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
// Depth-concatenate prev_activ and input data together.
uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
prev_activ_data_uint8};
- Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
- Concatenation<FusedActivationFunctionType::kNone, uint8>(
- 0, concat_input_arrays_data, concat_input_arrays_dims, 2,
- concat_temp_data_uint8, concat_temp_dims);
+ const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+ &prev_activ_shape};
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = 2;
+ Concatenation(concat_params, concat_input_arrays_shapes,
+ concat_input_arrays_data, concat_temp_shape,
+ concat_temp_data_uint8);
// Implementation of the fully connected node inside the LSTM cell.
// The operands are 8-bit integers, the accumulators are internally 32bit
@@ -2727,6 +2828,37 @@
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+ const uint8* prev_activ_data_uint8,
+ const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+ const Dims<4>& weights_dims, const int32* bias_data_int32,
+ const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+ const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+ const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+ const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+ const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+ const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+ int32 accum_multiplier, int accum_shift,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights_zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
+
+ LstmCell<StateIntegerBits>(
+ op_params, DimsToShape(input_dims), input_data_uint8,
+ DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+ DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+ DimsToShape(output_state_dims), output_state_data_int16,
+ DimsToShape(output_activ_dims), output_activ_data_uint8,
+ DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+ DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+}
+
template <typename Scalar>
void Split(const SplitParams& params, const RuntimeShape& input_shape,
const Scalar* input_data, const RuntimeShape* const* output_shapes,
@@ -3464,7 +3596,7 @@
std::max(diff_min - 1, // Note use of > below instead of >= above.
MultiplyByQuantizedMultiplierSmallerThanOneExp(
rescaled_diff_min, reverse_scaling_divisor,
- kReverseShift * reverse_scaling_right_shift));
+ -reverse_scaling_right_shift));
for (int c = 0; c < depth; ++c) {
int32 input_diff =
@@ -3505,8 +3637,7 @@
LogSoftmax(params, input_shape, input_data, output_shape, output_data);
}
-inline void Logistic(const LogisticParams& params,
- const RuntimeShape& input_shape, const float* input_data,
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3517,13 +3648,13 @@
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- LogisticParams params;
- // No params currently needed by float Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Logistic(input_shape, input_data, output_shape, output_data);
}
inline void Logistic(const LogisticParams& params,
@@ -3609,9 +3740,8 @@
Logistic(params, input_shape, input_data, output_shape, output_data);
}
-inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
- const float* input_data, const RuntimeShape& output_shape,
- float* output_data) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3621,13 +3751,13 @@
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- TanhParams params;
- // Currently no params needed for float Tanh.
- Tanh(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Tanh(input_shape, input_data, output_shape, output_data);
}
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
@@ -4603,6 +4733,16 @@
}
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
@@ -4615,6 +4755,16 @@
}
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T, typename Op>
void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
const T* input1_data,
@@ -4690,6 +4840,16 @@
std::greater<T1>());
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T1, typename T2, typename T3>
+inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const RuntimeShape& input2_shape, const T3* input2_data,
+ const RuntimeShape& output_shape, T2* output_data) {
+ // Drop shape of second input: not needed.
+ ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Transpose(const TransposeParams& params,
const RuntimeShape& unextended_input_shape, const T* input_data,
@@ -4959,9 +5119,11 @@
op_params.left_shift = left_shift;
op_params.input1_offset = input1_offset;
op_params.input1_multiplier = input1_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
op_params.input1_shift = kReverseShift * input1_shift;
op_params.input2_offset = input2_offset;
op_params.input2_multiplier = input2_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
op_params.input2_shift = kReverseShift * input2_shift;
ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
@@ -5093,9 +5255,11 @@
op_params.left_shift = left_shift;
op_params.input1_offset = input1_offset;
op_params.input1_multiplier = input1_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
op_params.input1_shift = kReverseShift * input1_shift;
op_params.input2_offset = input2_offset;
op_params.input2_multiplier = input2_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
op_params.input2_shift = kReverseShift * input2_shift;
BroadcastComparison4DSlowWithScaling<T, F>(
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index b70a87d..3e03087 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -18,6 +18,7 @@
#include <cstring>
#include <iterator>
+#include "absl/base/macros.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
namespace tflite {
@@ -424,7 +425,7 @@
return flat_size;
}
-// Deprecated. Prefer FlatSize.
+ABSL_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
return FlatSize(dims);
}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 6e35799..2f4b663 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -158,7 +158,9 @@
AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D());
AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D());
AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
- AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
+ AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D(),
+ /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_SVDF, Register_SVDF());
AddBuiltin(BuiltinOperator_RNN, Register_RNN());
AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 3a53430..61e9106 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -470,6 +470,17 @@
strides.mutable_list()->add_i(src_op.stride_height);
strides.mutable_list()->add_i(src_op.stride_width);
strides.mutable_list()->add_i(1);
+ // TODO(b/116063589): To return a working TF GraphDef, we should be returning
+ // the correct SpaceToBatchNd and BatchToSpaceND operation before and after
+ // the conv since TF doesn't support dilations.
+ if ((src_op.dilation_width_factor != 1) ||
+ (src_op.dilation_height_factor != 1)) {
+ auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
+ dilations.mutable_list()->add_i(1);
+ dilations.mutable_list()->add_i(src_op.dilation_height_factor);
+ dilations.mutable_list()->add_i(src_op.dilation_width_factor);
+ dilations.mutable_list()->add_i(1);
+ }
string padding;
if (src_op.padding.type == PaddingType::kSame) {
padding = "SAME";
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
index 84680b9..aba7536 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
@@ -38,7 +38,7 @@
examples below use `tflite_convert` for simplicity.
* Example: `tflite_convert --output_file=...`
* `bazel`: In order to run the latest version of TOCO, [clone the TensorFlow
- repository](https://www.tensorflow.org/install/install_sources#clone_the_tensorflow_repository)
+ repository](https://www.tensorflow.org/install/source)
and use `bazel`. This is the recommended approach for converting models that
utilize new features that were not supported by TOCO in TensorFlow 1.9.
* Example: `bazel run
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index 51f808d..910fa4c 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -260,7 +260,7 @@
In order to run the latest version of the TOCO Python API, clone the TensorFlow
repository, configure the installation, and build and install the pip package.
Detailed instructions are available
-[here](https://www.tensorflow.org/install/install_sources).
+[here](https://www.tensorflow.org/install/source).
### Converting models prior to TensorFlow 1.9. <a name="pre-tensorflow-1.9"></a>
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index fdd0632..4d213b3 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -133,7 +133,6 @@
DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose)
DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
-DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape)
DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
@@ -266,6 +265,17 @@
bool has_default_ranges_flag_ = false;
};
+class IdentifyDilatedConv : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "IdentifyDilatedConv"; }
+ bool identify_depthwise_conv() const { return identify_depthwise_conv_; }
+ void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; }
+
+ private:
+ bool identify_depthwise_conv_ = true;
+};
+
#undef DECLARE_GRAPH_TRANSFORMATION
} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
index d49857c..aac77eb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
@@ -53,50 +53,11 @@
// thrown in just for the extra headache. Padding adapts non-conforming input
// sizes, and can be discarded. The bias is necessary, so is kept.
-bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
- const auto it = model->operators.begin() + op_index;
- auto* stb_op = it->get();
-
- // 1. IDENTIFY OPERATORS
- // ***************************************************************************
- // SpaceToBatch Op.
- if (stb_op->type != OperatorType::kSpaceToBatchND) {
- return false;
- }
- if (stb_op->inputs.size() != 3) {
- return false;
- }
- CHECK_EQ(stb_op->outputs.size(), 1);
- // Extract the dilation factor from Input[1] of SpaceToBatch
- // TODO(mjmatthews): Support 2D dilation factors.
- const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
- if (!block_shape_array.buffer) {
- return false;
- }
- CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
- int dilation_factor =
- block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
-
- // Expand Op
- auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
- if (!post_stb_op) {
- return false;
- }
- bool has_expand_op = false;
- if (post_stb_op->type == OperatorType::kExpandDims) {
- has_expand_op = true;
- CHECK_EQ(post_stb_op->inputs.size(), 2);
- CHECK_EQ(post_stb_op->outputs.size(), 1);
- }
-
- // Conv Op
- const string& input_of_conv_op =
- has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
- auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
- if (conv_base_op->type != OperatorType::kConv) {
- return false;
- }
- auto* conv_op = static_cast<ConvOperator*>(conv_base_op);
+template <typename T>
+bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
+ Operator* post_stb_op, bool has_expand_op,
+ int dilation_factor) {
+ auto* conv_op = static_cast<T*>(conv_base_op);
if (conv_op->inputs.size() != 2) {
// The conv op must only have weights, no bias.
return false;
@@ -158,8 +119,6 @@
CHECK_EQ(bias_add_op->inputs.size(), 2);
CHECK_EQ(bias_add_op->outputs.size(), 1);
- LOG(INFO) << "Identified sub-network emulating dilated convolution.";
-
// 2. RE-WIRE OPERATORS
// ***************************************************************************
// Re-use the existing Conv2D op.
@@ -206,9 +165,71 @@
DeleteArrayIfUnused(stb_op_inputs[1], model);
DeleteArrayIfUnused(stb_op_inputs[2], model);
- LOG(INFO) << "Replaced with Dilated Conv2D op outputting \""
- << conv_op->outputs[0] << "\".";
return true;
}
+bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* stb_op = it->get();
+
+ // 1. IDENTIFY OPERATORS
+ // ***************************************************************************
+ // SpaceToBatch Op.
+ if (stb_op->type != OperatorType::kSpaceToBatchND) {
+ return false;
+ }
+ if (stb_op->inputs.size() != 3) {
+ return false;
+ }
+ CHECK_EQ(stb_op->outputs.size(), 1);
+ // Extract the dilation factor from Input[1] of SpaceToBatch
+ // TODO(mjmatthews): Support 2D dilation factors.
+ const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
+ if (!block_shape_array.buffer) {
+ return false;
+ }
+ CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
+ int dilation_factor =
+ block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
+
+ // Expand Op
+ auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
+ if (!post_stb_op) {
+ return false;
+ }
+ bool has_expand_op = false;
+ if (post_stb_op->type == OperatorType::kExpandDims) {
+ has_expand_op = true;
+ CHECK_EQ(post_stb_op->inputs.size(), 2);
+ CHECK_EQ(post_stb_op->outputs.size(), 1);
+ }
+
+ // Conv Op
+ const string& input_of_conv_op =
+ has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
+ auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
+ bool changed = false;
+ if (conv_base_op->type == OperatorType::kConv) {
+ changed = ResolveDilatedConv<ConvOperator>(model, conv_base_op, stb_op,
+ post_stb_op, has_expand_op,
+ dilation_factor);
+ if (changed) {
+ LOG(INFO) << "Replaced sub-network with Dilated Conv2D op outputting \""
+ << conv_base_op->outputs[0] << "\".";
+ }
+ } else if (identify_depthwise_conv_ &&
+ conv_base_op->type == OperatorType::kDepthwiseConv) {
+ changed = ResolveDilatedConv<DepthwiseConvOperator>(
+ model, conv_base_op, stb_op, post_stb_op, has_expand_op,
+ dilation_factor);
+ if (changed) {
+ LOG(INFO)
+ << "Replaced sub-netork with Dilated DepthwiseConv2D op outputting \""
+ << conv_base_op->outputs[0] << "\".";
+ }
+ }
+
+ return changed;
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 6c72e20..f943da6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -285,7 +285,8 @@
const int kheight = weights_shape.dims(1);
const int kwidth = weights_shape.dims(2);
ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
- op->stride_height, 1, 1, op->padding.type,
+ op->stride_height, op->dilation_width_factor,
+ op->dilation_height_factor, op->padding.type,
model->GetArray(output_name).mutable_shape(),
&op->padding.GetOrCreateFixedPadding());
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
index 8266e2c..8e150db 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
@@ -25,29 +25,57 @@
namespace toco {
+namespace {
+
+void RenameArray(Model* model, const string& oldname,
+ const string& desired_newname) {
+ const string& newname = AvailableArrayName(*model, desired_newname);
+ auto& arrays = model->GetMutableArrayMap();
+ arrays[newname] = std::move(arrays[oldname]);
+ arrays.erase(oldname);
+ for (const auto& op : model->operators) {
+ for (string& input : op->inputs) {
+ if (input == oldname) {
+ input = newname;
+ }
+ }
+ for (string& output : op->outputs) {
+ if (output == oldname) {
+ output = newname;
+ }
+ }
+ }
+}
+
+} // namespace
+
// Reorder the elements of an input_array according to the input_axes_order and
// output_axes_order. Then adjust the shapes of the input and output arrays
// accordingly. Note that input_array must have a buffer (that is, it is a
// constant array).
template <typename T, ArrayDataType DataType>
void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
- Array* input_array, Array* output_array) {
- CHECK(input_array->buffer->type == DataType);
- CHECK(!output_array->buffer);
- auto& input_data = input_array->GetMutableBuffer<DataType>().data;
- std::vector<T> reordered_data;
- reordered_data.resize(RequiredBufferSizeForShape(output_array->shape()));
+ const Array& input_array, Array* output_array) {
+ DCHECK(input_array.buffer->type == DataType);
+ DCHECK(!output_array->buffer);
+ const auto& input_data = input_array.GetBuffer<DataType>().data;
+ auto& output_data = output_array->GetMutableBuffer<DataType>().data;
+ output_data.resize(RequiredBufferSizeForShape(output_array->shape()));
// TODO(b/62904716) Shapes should be used directly.
- Shape input_shape = input_array->shape();
+ Shape input_shape = input_array.shape();
Shape output_shape = output_array->shape();
if (AxesCount(input_axes_order) == 2) {
UnextendShape(&input_shape, 2);
UnextendShape(&output_shape, 2);
}
ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape,
- input_data.data(), reordered_data.data());
- input_data = reordered_data;
- input_array->copy_shape(output_array->shape());
+ input_data.data(), output_data.data());
+ if (input_array.minmax) {
+ output_array->GetOrCreateMinMax() = input_array.GetMinMax();
+ }
+ if (input_array.narrow_range) {
+ output_array->narrow_range = true;
+ }
}
bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
@@ -57,8 +85,11 @@
return false;
}
auto* reorder_op = static_cast<ReorderAxesOperator*>(op);
- const auto& input_array_name = reorder_op->inputs[0];
- const auto& output_array_name = reorder_op->outputs[0];
+
+ // Intentionally copies, not references.
+ const string input_array_name = reorder_op->inputs[0];
+ const string output_array_name = reorder_op->outputs[0];
+
auto& input_array = model->GetArray(input_array_name);
auto& output_array = model->GetArray(output_array_name);
if (!input_array.buffer) {
@@ -72,31 +103,23 @@
if (input_array.buffer->type == ArrayDataType::kFloat) {
ReorderAxes<float, ArrayDataType::kFloat>(reorder_op->input_axes_order,
reorder_op->output_axes_order,
- &input_array, &output_array);
- } else if (input_array.buffer->type == ArrayDataType::kInt32) {
+ input_array, &output_array);
+ } else if (input_array.buffer->type == ArrayDataType::kUint8) {
+ // TODO(benoitjacob): This path seems unused.
+ // ReorderAxes is only used when importing from
+ // TensorFlow GraphDef, which does not support quantized nodes.
ReorderAxes<uint8, ArrayDataType::kUint8>(reorder_op->input_axes_order,
reorder_op->output_axes_order,
- &input_array, &output_array);
+ input_array, &output_array);
} else {
LOG(FATAL) << "Cannot ReorderAxes unless input buffer is float or uint8.";
}
- input_array.copy_shape(output_array.shape());
-
- // Update the edges of the graph to point to the input array
- for (const auto& other_op : model->operators) {
- for (auto& input : other_op->inputs) {
- if (input == output_array_name) {
- input = input_array_name;
- }
- }
- }
-
AddMessageF("Reordered axes for array %s", input_array_name);
- // Remove the op and output array.
- model->EraseArray(output_array_name);
- model->operators.erase(it);
+ DeleteOpAndArraysIfUnused(model, op);
+ RenameArray(model, output_array_name, input_array_name);
+
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index fcf30bd..65346c4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -24,6 +24,37 @@
namespace toco {
+namespace {
+
+TransposeOperator* FindTransposeOpWithInput(const Model& model,
+ const string& array_name) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ Operator* op = it->get();
+ if (op->type != OperatorType::kTranspose) {
+ continue;
+ }
+ if (op->inputs[0] != array_name) {
+ continue;
+ }
+ const auto& permutation_array = model.GetArray(op->inputs[1]);
+ if (permutation_array.data_type != ArrayDataType::kInt32) {
+ continue;
+ }
+ const auto& permutation_data =
+ permutation_array.GetBuffer<ArrayDataType::kInt32>().data;
+ if (permutation_data.size() != 2) {
+ continue;
+ }
+ if (permutation_data[0] != 1 || permutation_data[1] != 0) {
+ continue;
+ }
+ return static_cast<TransposeOperator*>(op);
+ }
+ return nullptr;
+}
+
+} // namespace
+
bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
auto matmul_it = model->operators.begin() + op_index;
if (matmul_it->get()->type != OperatorType::kMatMul) {
@@ -37,7 +68,13 @@
// TransposeOperator. However, the second input is supposed to be 2D, so we
// can actually handle transposition of that matrix, which happens to be more
// common anyway.
- CHECK(!matmul_op->transpose_a);
+ if (matmul_op->transpose_a) {
+ AddMessageF(
+ "Not replacing %s by a FullyConnected operator, because it has "
+ "the transpose_a attribute",
+ LogName(*matmul_op));
+ return false;
+ }
// Reorder the axes on the second input. TensorFlow uses row-major ordering
// on both inputs, however this is inefficient for the FullyConnected
@@ -46,18 +83,35 @@
string input_lhs = matmul_op->inputs[0];
string input_rhs = matmul_op->inputs[1];
if (!matmul_op->transpose_b) {
- auto* transpose_op = new TransposeOperator;
- transpose_op->inputs = {
- matmul_op->inputs[1],
- CreateInt32Array(model,
- AvailableArrayName(
- *model, matmul_op->inputs[1] + "/transpose/perm"),
- {1, 0})};
- transpose_op->outputs = {
- AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")};
- model->GetOrCreateArray(transpose_op->outputs[0]);
- model->operators.emplace(matmul_it, transpose_op);
-
+ // Need to transpose input_rhs, by inserting a TransposeOperator.
+ // First, check if there already is a TransposeOperator transposing that
+ // array, so we can just reuse it.
+ auto* transpose_op = FindTransposeOpWithInput(*model, input_rhs);
+ if (!transpose_op) {
+ AddMessageF(
+ "While replacing %s by a FullyConnected operator, created new "
+ "Transpose op wrapping RHS input array %s",
+ LogName(*matmul_op), input_rhs);
+ // No such TransposeOperator found. Create one now.
+ transpose_op = new TransposeOperator;
+ transpose_op->inputs = {
+ input_rhs,
+ CreateInt32Array(
+ model, AvailableArrayName(*model, input_rhs + "/transpose/perm"),
+ {1, 0})};
+ transpose_op->outputs = {
+ AvailableArrayName(*model, input_rhs + "/transpose")};
+ model->GetOrCreateArray(transpose_op->outputs[0]);
+ model->operators.emplace(matmul_it, transpose_op);
+ // Sanity check
+ DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_rhs));
+ } else {
+ AddMessageF(
+ "While replacing %s by a FullyConnected operator, reused existing "
+ "Transpose op wrapping RHS input array %s",
+ LogName(*matmul_op), input_rhs);
+ }
+ // Re-wire: have the matmul consume the transposed array.
input_rhs = transpose_op->outputs[0];
}
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 4c678e7..e02d000 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -641,6 +641,23 @@
CHECK_EQ(strides.i(3), 1);
conv->stride_height = strides.i(1);
conv->stride_width = strides.i(2);
+ if (HasAttr(node, "dilations")) {
+ const auto& dilations = GetListAttr(node, "dilations");
+ TF_RETURN_IF_ERROR(
+ ExpectValue(dilations.i_size(), 4, "number of dilations"));
+ if (dilations.i(0) != 1 || dilations.i(3) != 1) {
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Can only import Conv ops with dilation along the height "
+ "(1st) or width (2nd) axis. TensorFlow op \"",
+ node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
+ dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
+ }
+ conv->dilation_height_factor = dilations.i(1);
+ conv->dilation_width_factor = dilations.i(2);
+ } else {
+ conv->dilation_height_factor = 1;
+ conv->dilation_width_factor = 1;
+ }
const auto& padding = GetStringAttr(node, "padding");
if (padding == "SAME") {
conv->padding.type = PaddingType::kSame;
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 0fd2732..6e207fd 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -2084,6 +2084,7 @@
}
}
const ArrayMap& GetArrayMap() const { return arrays; }
+ ArrayMap& GetMutableArrayMap() { return arrays; }
int64 ArithmeticOpsCount() const { return ops_count; }
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index a7c1715..a08b024 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -101,7 +101,6 @@
transformations->Add(new ResolveTensorFlowSwitch);
transformations->Add(new ResolveTensorFlowConcat);
transformations->Add(new ResolveMultiplyByZero);
- transformations->Add(new IdentifyDilatedConv);
transformations->Add(new IdentifyL2Normalization);
transformations->Add(new IdentifyL2Pool);
transformations->Add(new IdentifyRelu1);
@@ -282,6 +281,14 @@
}
}
transformations.Add(new ResolveConstantConcatenation);
+ // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise
+ // conv, so we need to make sure we don't convert to dilated depthwise conv
+ // when outputing to TF GraphDef.
+ auto* identify_dilated_conv = new IdentifyDilatedConv;
+ if (output_format == TENSORFLOW_GRAPHDEF) {
+ identify_dilated_conv->set_identify_depthwise_conv(false);
+ }
+ transformations.Add(identify_dilated_conv);
RunGraphTransformations(model, "general graph transformations",
transformations);
@@ -367,9 +374,7 @@
}
// Deduplicate large constant arrays.
- if (toco_flags.has_dedupe_array_min_size_bytes()) {
- DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());
- }
+ DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());
LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
diff --git a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
index 4929133..80cdb2f 100644
--- a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
+++ b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
@@ -36,7 +36,7 @@
"source": [
"## Overview\n",
"\n",
- "[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) now supports\n",
+ "[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n",
"converting weights to 8 bit precision as part of model conversion from\n",
"tensorflow graphdefs to TFLite's flat buffer format. Weight quantization\n",
"achieves a 4x reduction in the model size. In addition, TFLite supports on the\n",
@@ -542,7 +542,7 @@
},
"outputs": [],
"source": [
- "print(eval_model(interpreter_quant, mnist_ds))"
+ "print(eval_model(interpreter, mnist_ds))"
]
},
{
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py
index 25ec475..dab1e02 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad.py
@@ -31,7 +31,7 @@
See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
or this
- [intro](http://cs.stanford.edu/~ppasupat/a9online/uploads/proximal_notes.pdf).
+ [intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
"""
def __init__(self, learning_rate, initial_accumulator_value=0.1,
diff --git a/tensorflow/contrib/quantization/README.md b/tensorflow/contrib/quantization/README.md
index 359950a..826e8db 100644
--- a/tensorflow/contrib/quantization/README.md
+++ b/tensorflow/contrib/quantization/README.md
@@ -2,6 +2,6 @@
If you are looking for quantized training rewrites that allow for training
quantized models that work with
-[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/), you should look at
+[TensorFlow Lite](https://www.tensorflow.org/lite/), you should look at
the [contrib/quantize](https://www.tensorflow.org/api_docs/python/tf/contrib/quantize)
package.
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index c59f667..23e3a25 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -20,9 +20,13 @@
srcs_version = "PY2AND3",
deps = [
":common",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python:variable_scope",
diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md
index 3f1e7d2..0ab19c9 100644
--- a/tensorflow/contrib/quantize/README.md
+++ b/tensorflow/contrib/quantize/README.md
@@ -105,7 +105,7 @@
--std_value=127.5 --mean_value=127.5
```
-See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../mobile/tflite/).
+See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../lite/).
## Quantized accuracy results
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index b27117d..e6c04bc 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -34,10 +34,10 @@
'ScalarSummary')
# Valid activation ops for quantization end points.
-_ACTIVATION_OP_SUFFIXES = ['/Relu6', '/Relu', '/Identity']
+_ACTIVATION_OP_SUFFIXES = ['Relu6', 'Relu', 'Identity']
# Regular expression for recognizing nodes that are part of batch norm group.
-_BATCHNORM_RE = re.compile(r'^(.*)/BatchNorm/batchnorm')
+_BATCHNORM_RE = re.compile(r'^(.*)BatchNorm/batchnorm')
def BatchNormGroups(graph):
diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py
index 2b26302..a3ce041 100644
--- a/tensorflow/contrib/quantize/python/common_test.py
+++ b/tensorflow/contrib/quantize/python/common_test.py
@@ -13,21 +13,26 @@
# limitations under the License.
# ==============================================================================
"""Tests for common utilities in this package."""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
+from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import common
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
+batch_norm = layers.batch_norm
+conv2d = layers.conv2d
+
class CommonTest(test_util.TensorFlowTestCase):
@@ -87,6 +92,56 @@
for i in inputs:
self.assertIn(i, op.inputs)
+ def testBatchNormScope(self):
+ batch_size, height, width, depth = 5, 128, 128, 3
+ g = ops.Graph()
+ with g.as_default():
+ inputs = array_ops.zeros((batch_size, height, width, depth))
+ stride = 1
+ out_depth = 32
+ scope = ''
+ node = conv2d(
+ inputs,
+ out_depth, [2, 2],
+ stride=stride,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ normalizer_fn=batch_norm,
+ normalizer_params=self._BatchNormParams(False),
+ scope=scope)
+
+ node = nn_ops.relu(node, name='Relu6')
+ bn_list = common.BatchNormGroups(g)
+ with open('/tmp/common_test.pbtxt', 'w') as f:
+ f.write(str(g.as_graph_def()))
+
+ # Exactly one batch norm layer with empty scope should be found
+ self.assertEqual(len(bn_list), 1)
+ self.assertEqual(bn_list[0], '')
+
+ def _BatchNormParams(self, fused=False, force_updates=False):
+ params = {
+ 'center': True,
+ 'scale': True,
+ 'decay': 1.0 - 0.003,
+ 'fused': fused
+ }
+ return params
+
+ def _WeightInit(self, stddev):
+ """Returns a truncated normal variable initializer.
+
+ Function is defined purely to shorten the name so that it stops wrapping.
+
+ Args:
+ stddev: Standard deviation of normal variable.
+
+ Returns:
+ An initializer that initializes with a truncated normal variable.
+ """
+ return init_ops.truncated_normal_initializer(stddev=stddev, seed=1234)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index 2971b28..e5790a6 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -95,8 +95,7 @@
_ComputeBatchNormCorrections(
context='',
match=match,
- freeze_batch_norm_delay=freeze_batch_norm_delay,
- fused_batch_norm=True))
+ freeze_batch_norm_delay=freeze_batch_norm_delay))
# The shape of depthwise weights is different, so we need to reshape the
# multiplier_tensor to ensure that the scaled_weight_tensor has the
# expected shape.
@@ -296,8 +295,7 @@
batch_to_space_op=batch_to_space_op)
-def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
- fused_batch_norm):
+def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
"""Computes batch norm correction params.
Before batch normalization is frozen:
@@ -327,14 +325,14 @@
computation.
freeze_batch_norm_delay: Delay in steps at which computation switches
from regular batch norm to frozen mean and variance.
- fused_batch_norm: Bool, true if fused batch norm is used.
+
Returns:
A tuple of correction_scale, correction_recip, correction_offset
"""
g = ops.get_default_graph()
- prefix = '' if not context else context + '/'
+ prefix = '' if not context else context
with g.name_scope(prefix + 'batch_norm_correction'):
recip_sigma_mv = math_ops.rsqrt(
match.moving_variance_tensor + match.batch_epsilon)
@@ -495,8 +493,23 @@
# Treat consumer ops in bypass modules differently since they have Add
# operations instead of Relu* above.
- add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
- add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
+ # Changes to make sure that the correct scope is selected for the bypass add
+ # The rule here is that if the scope is of the form: str1/str2 for the
+ # batch norm,
+ # the bypass add is at scope str1. If bn is of scope just str1, then the
+ # bypass add is at scope ''.
+ # If there is no batch norm, then there is no bypass add.
+ add_bypass_ctx = ''
+ if bn:
+ try:
+ add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
+ except AttributeError:
+ add_bypass_ctx = ''
+
+ if add_bypass_ctx:
+ add_bypass_ctx = add_bypass_ctx + '/'
+
+ add_bypass = graph.get_operation_by_name(add_bypass_ctx + 'Add')
nodes_modified_count = common.RerouteTensor(
folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
if nodes_modified_count != 1:
@@ -505,8 +518,8 @@
def _IsValidUnfusedBatchNorm(graph, context):
"""Checks that the output of the unfused batch norm has consumers."""
- add_shift = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm_1/add_1')
+ add_shift = graph.get_operation_by_name(context +
+ 'BatchNorm/batchnorm_1/add_1')
# Ensure that the output tensor of batch norm has consumers, otherwise this
# is a dangling node and not a match.
return bool(add_shift.outputs[0].consumers())
@@ -538,7 +551,8 @@
if op.name.endswith(match_pattern):
split_name = op.name.split('/')
num_matches = len(set(split_name) & split_context)
- if num_matches > 0:
+
+ if num_matches > 0 or not scope:
match_dict[op.name] = num_matches
# match_dict contains matching op names from graph with values being
# number of matches to scope. We pick the key with the most matches
@@ -597,21 +611,21 @@
# op.name = MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
# will have 2 matches,scope with a different conv layer will have one match.
- op_suffix_mean = '/BatchNorm/moments/Squeeze'
- op_suffix_variance = '/BatchNorm/moments/Squeeze_1'
- op_suffix_epsilon = '/BatchNorm/batchnorm_1/add/y'
- op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay'
- op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay'
+ op_suffix_mean = 'BatchNorm/moments/Squeeze'
+ op_suffix_variance = 'BatchNorm/moments/Squeeze_1'
+ op_suffix_epsilon = 'BatchNorm/batchnorm_1/add/y'
+ op_suffix_bn_decay_mean = 'BatchNorm/AssignMovingAvg/decay'
+ op_suffix_bn_decay_var = 'BatchNorm/AssignMovingAvg_1/decay'
if variable_scope.get_variable_scope().use_resource:
- op_suffix_gamma = '/BatchNorm/gamma/Read/ReadVariableOp'
+ op_suffix_gamma = 'BatchNorm/gamma/Read/ReadVariableOp'
op_suffix_moving_variance = (
- '/BatchNorm/moving_variance/Read/ReadVariableOp')
- op_suffix_moving_mean = ('/BatchNorm/moving_mean/Read/ReadVariableOp')
+ 'BatchNorm/moving_variance/Read/ReadVariableOp')
+ op_suffix_moving_mean = ('BatchNorm/moving_mean/Read/ReadVariableOp')
else:
- op_suffix_gamma = '/BatchNorm/gamma'
- op_suffix_moving_variance = '/BatchNorm/moving_variance/read'
- op_suffix_moving_mean = '/BatchNorm/moving_mean/read'
+ op_suffix_gamma = 'BatchNorm/gamma'
+ op_suffix_moving_variance = 'BatchNorm/moving_variance/read'
+ op_suffix_moving_mean = 'BatchNorm/moving_mean/read'
# Parse through list of ops to find relevant ops
batch_mean_tensor = _FindMatchingTensor(graph, op_suffix_mean, context)
@@ -679,8 +693,7 @@
the folded graph (add_fold).
"""
mul_scale_name = 'mul_1' if has_scaling else 'mul'
- mul_scale = graph.get_operation_by_name(context +
- '/BatchNorm/batchnorm_1/' +
+ mul_scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
mul_scale_name)
op_below = mul_scale.inputs[0].op
# Skip over the BatchToSpace operation in the case of atrous convolutions.
@@ -697,8 +710,7 @@
_ComputeBatchNormCorrections(
context=context,
match=match,
- freeze_batch_norm_delay=freeze_batch_norm_delay,
- fused_batch_norm=False))
+ freeze_batch_norm_delay=freeze_batch_norm_delay))
# Special handling for weights of depthwise convolution.
if op_below.type == 'DepthwiseConv2dNative':
new_shape = [
@@ -706,27 +718,27 @@
weights.get_shape().as_list()[3]
]
scale_name = 'mul' if has_scaling else 'Rsqrt'
- scale = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm_1/' + scale_name)
+ scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
+ scale_name)
scale = array_ops.reshape(scale.outputs[0], new_shape,
- context + '/scale_reshape')
+ context + 'scale_reshape')
if correction_scale is not None:
correction_scale = array_ops.reshape(correction_scale, new_shape,
- context + '/correction_reshape')
+ context + 'correction_reshape')
with ops.device(mul_scale.device):
weights = math_ops.multiply(correction_scale, weights,
- context + '/correction_mult')
+ context + 'correction_mult')
- mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights),
- (1, scale)])
+ mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights),
+ (1, scale)])
elif op_below.type in ['Conv2D', 'MatMul']:
if correction_scale is not None:
with ops.device(mul_scale.device):
weights = math_ops.multiply(correction_scale, weights,
- context + '/correction_mult')
- mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)])
+ context + 'correction_mult')
+ mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights)])
else:
raise ValueError('Cannot handle operation of type: %s' % op_below.type)
_AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0])
@@ -734,8 +746,8 @@
conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold',
[(1, mul_fold.outputs[0])])
- add_shift = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm_1/add_1')
+ add_shift = graph.get_operation_by_name(context +
+ 'BatchNorm/batchnorm_1/add_1')
corrected_output = conv_or_fc_folded.outputs[0]
# Copy the batch to space operation if we have a atrous convolution.
@@ -748,10 +760,10 @@
if correction_offset is not None:
with ops.device(conv_or_fc_folded.device):
corrected_output = math_ops.multiply(correction_recip, corrected_output,
- context + '/post_conv_mul')
+ context + 'post_conv_mul')
corrected_output = math_ops.add(corrected_output, (correction_offset),
- context + '/correction_add')
- add_fold = _CloneOp(add_shift, context + '/add_fold', [(0, corrected_output)])
+ context + 'correction_add')
+ add_fold = _CloneOp(add_shift, context + 'add_fold', [(0, corrected_output)])
_AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0])
return add_shift, add_fold
@@ -930,7 +942,7 @@
Returns:
A boolean indicating whether this batch norm layer has scaling enabled.
"""
- rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm_1/Rsqrt')
+ rsqrt_op = graph.get_operation_by_name(bn + 'BatchNorm/batchnorm_1/Rsqrt')
rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op)
return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index e88db0a..5e63d33 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -97,8 +97,11 @@
layer_match.activation_op)
add_context = context
if layer_match.bypass_op:
- add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
-
+ pattern_match_result = re.search(r'^(.*)/([^/]+)', context)
+ if pattern_match_result is not None:
+ add_context = pattern_match_result.group(1)
+ else:
+ add_context = ''
# If `scope` is given, only quantize it if the producer of weights
# (usually it's the layer op) is in the right scope.
_InsertQuantOp(
@@ -156,8 +159,12 @@
# Quantize bypass ops that occur after the activation.
if layer_match.post_activation_bypass_op is not None:
- post_activation_bypass_context = re.search(
- r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name).group(1)
+ pattern_match_result = re.search(
+ r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name)
+ if pattern_match_result is not None:
+ post_activation_bypass_context = pattern_match_result.group(1)
+ else:
+ post_activation_bypass_context = ''
# If `scope` is given, only quantize it if the producer is in the right
# scope.
# Make sure the op following this isn't an activation. In which case, we
diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
index 31a2955..f6bf57a 100644
--- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
@@ -58,85 +58,102 @@
]
for params in parameters_list:
# Test everything with resource variables and normal variables.
- test_fn(params[0], params[1], params[2], params[3], False)
- test_fn(params[0], params[1], params[2], params[3], True)
+ test_fn(params[0], params[1], params[2], params[3], False, None)
+ test_fn(params[0], params[1], params[2], params[3], True, None)
+ # Test with both empty scope and an example scope
+ test_fn(params[0], params[1], params[2], params[3], False, 'test')
+ test_fn(params[0], params[1], params[2], params[3], True, 'test')
def _AssertCorrectQuantizedGraphWithoutBatchNorm(
self, graph, scope, layer, activation_op_name, with_bypass, delay,
use_resource):
quantization_node_name = 'FakeQuantWithMinMaxVars'
- weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
- quantization_node_name)
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ delim = '/' if conv_scope else ''
+
+ if scope:
+ scope = scope + '/'
+ weights_quant = graph.get_operation_by_name(
+ conv_scope + delim + 'weights_quant/' + quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
# Assemble the expected inputs.
if use_resource:
expected_inputs = [
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
if layer == 'DepthwiseConv2dNative':
- expected_inputs.append(scope + '/depthwise/ReadVariableOp')
+ expected_inputs.append(conv_scope + delim + 'depthwise/ReadVariableOp')
else:
- expected_inputs.append(scope + '/' + layer + '/ReadVariableOp')
+ expected_inputs.append(conv_scope + delim + layer + '/ReadVariableOp')
else:
expected_inputs = [
- scope + '/weights_quant/AssignMinLast',
- scope + '/weights_quant/AssignMaxLast',
+ conv_scope + delim + 'weights_quant/AssignMinLast',
+ conv_scope + delim + 'weights_quant/AssignMaxLast',
]
if layer == 'DepthwiseConv2dNative':
- expected_inputs.append(scope + '/depthwise_weights/read')
+ expected_inputs.append(conv_scope + delim + 'depthwise_weights/read')
else:
- expected_inputs.append(scope + '/weights/read')
+ expected_inputs.append(conv_scope + delim + 'weights/read')
self._AssertInputOpsAre(weights_quant, expected_inputs)
if delay and delay > 0:
- output_op_name = scope + '/weights_quant/delayed_quant/Switch_1'
+ output_op_name = (
+ conv_scope + delim + 'weights_quant/delayed_quant/Switch_1')
else:
if layer == 'DepthwiseConv2dNative':
- output_op_name = scope + '/depthwise'
+ output_op_name = conv_scope + delim + 'depthwise'
else:
- output_op_name = scope + '/' + layer
+ output_op_name = conv_scope + delim + layer
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
- conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
- quantization_node_name)
+ conv_quant = graph.get_operation_by_name(
+ conv_scope + delim + 'conv_quant/' + quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
- scope + '/BiasAdd',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim + 'BiasAdd',
]
else:
expected_inputs = [
- scope + '/conv_quant/AssignMinEma',
- scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
+ conv_scope + delim + 'conv_quant/AssignMinEma',
+ conv_scope + delim + 'conv_quant/AssignMaxEma',
+ conv_scope + delim + 'BiasAdd'
]
self._AssertInputOpsAre(conv_quant, expected_inputs)
- output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
- if delay else 'test/Add')
+
+ output_op_name = (
+ conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
+ if delay else scope + 'Add')
self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])
- act_quant = graph.get_operation_by_name('test/act_quant/' +
+ act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
- 'test/' + activation_op_name,
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ scope + activation_op_name,
]
else:
expected_inputs = [
- 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
- 'test/' + activation_op_name
+ scope + 'act_quant/AssignMinEma', scope + 'act_quant/AssignMaxEma',
+ scope + activation_op_name
]
self._AssertInputOpsAre(act_quant, expected_inputs)
- output_op_name = ('test/act_quant/delayed_quant/Switch_1'
- if delay else 'control_dependency')
+ output_op_name = (
+ scope + 'act_quant/delayed_quant/Switch_1'
+ if delay else 'control_dependency')
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
self._AssertIdempotent(graph)
@@ -145,7 +162,8 @@
self._TestQuantize_Conv2dWithoutBatchNorm)
def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
- with_bypass, delay, use_resource):
+ with_bypass, delay, use_resource,
+ scope):
"""Tests quantization: inputs -> Conv2d no batch norm -> Activation.
Args:
@@ -156,6 +174,7 @@
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -165,7 +184,9 @@
stride = 1 if with_bypass else 2
out_depth = 3 if with_bypass else 32
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = conv2d(
inputs,
out_depth, [5, 5],
@@ -173,16 +194,19 @@
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
- scope=scope)
+ scope=conv_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
quantize.Quantize(graph, True, quant_delay=delay)
+ if conv_scope is None:
+ conv_scope = ''
+
self._AssertCorrectQuantizedGraphWithoutBatchNorm(
graph, scope, 'Conv2D', activation_op_name, with_bypass, delay,
use_resource)
@@ -192,7 +216,7 @@
self._TestQuantize_FCWithoutBatchNorm)
def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name,
- with_bypass, delay, use_resource):
+ with_bypass, delay, use_resource, scope):
"""Tests quantization: inputs -> FC no batch norm -> Activation.
Args:
@@ -203,6 +227,7 @@
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -211,16 +236,18 @@
inputs = array_ops.zeros((batch_size, depth))
out_depth = 256 if with_bypass else 128
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ fc_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = fully_connected(
inputs,
out_depth,
weights_initializer=self._WeightInit(0.03),
activation_fn=activation_fn,
- scope=scope)
+ scope=fc_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
@@ -235,7 +262,8 @@
self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)
def _TestQuantize_DepthwiseConv2dWithoutBatchNorm(
- self, activation, activation_op_name, with_bypass, delay, use_resource):
+ self, activation, activation_op_name, with_bypass, delay, use_resource,
+ scope):
"""Tests quantization: inputs -> DWConv2d no batch norm -> Activation.
Args:
@@ -246,6 +274,7 @@
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -254,7 +283,10 @@
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
+
node = separable_conv2d(
inputs,
None, [5, 5],
@@ -263,10 +295,10 @@
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
- scope=scope)
+ scope=conv_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
@@ -280,8 +312,9 @@
self._RunWithoutBatchNormTestOverParameters(
self._TestQuantize_AtrousConvWithoutBatchNorm)
- def _TestQuantize_AtrousConvWithoutBatchNorm(
- self, activation, activation_op_name, with_bypass, delay, use_resource):
+ def _TestQuantize_AtrousConvWithoutBatchNorm(self, activation,
+ activation_op_name, with_bypass,
+ delay, use_resource, scope):
"""Tests quantization: inputs -> atrous conv no batch norm -> Activation.
Args:
@@ -292,6 +325,7 @@
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -300,7 +334,10 @@
inputs = array_ops.zeros((batch_size, height, width, depth))
dilation_rate = 2
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
+
node = separable_conv2d(
inputs,
None, [3, 3],
@@ -309,10 +346,10 @@
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
- scope=scope)
+ scope=conv_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
@@ -353,78 +390,96 @@
]
for params in parameters_list:
# Test everything with resource variables and normal variables.
- test_fn(params[0], params[1], params[2], params[3], params[4], False)
- test_fn(params[0], params[1], params[2], params[3], params[4], True)
+ test_fn(params[0], params[1], params[2], params[3], params[4], False,
+ None)
+ test_fn(params[0], params[1], params[2], params[3], params[4], True, None)
+ test_fn(params[0], params[1], params[2], params[3], params[4], False,
+ 'test')
+ test_fn(params[0], params[1], params[2], params[3], params[4], True,
+ 'test')
def _AssertCorrectQuantizedGraphWithBatchNorm(self, graph, scope, layer,
activation_op_name, with_bypass,
delay, use_resource):
quantization_node_name = 'FakeQuantWithMinMaxVars'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ delim = '/' if conv_scope else ''
+
+ if scope:
+ scope = scope + '/'
+
weights_quant = graph.get_operation_by_name(
- scope + '/weights_quant/' + quantization_node_name)
+ conv_scope + delim + 'weights_quant/' + quantization_node_name)
+
self.assertEqual(weights_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
else:
expected_inputs = [
- scope + '/weights_quant/' + 'AssignMinLast',
- scope + '/weights_quant/' + 'AssignMaxLast'
+ conv_scope + delim + 'weights_quant/' + 'AssignMinLast',
+ conv_scope + delim + 'weights_quant/' + 'AssignMaxLast'
]
- expected_inputs.append(scope + '/mul_fold')
+ expected_inputs.append(conv_scope + delim + 'mul_fold')
self._AssertInputOpsAre(weights_quant, expected_inputs)
if layer == 'DepthwiseConv2dNative':
- output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
- if delay else '/depthwise_Fold')
+ output_op_name = conv_scope + delim + (
+ 'weights_quant/delayed_quant/Switch_1' if delay else 'depthwise_Fold')
else:
- output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
- if delay else '/' + layer + '_Fold')
+ output_op_name = conv_scope + delim + (
+ 'weights_quant/delayed_quant/Switch_1' if delay else layer + '_Fold')
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
conv_quant = graph.get_operation_by_name(
- scope + '/conv_quant/' + quantization_node_name)
+ conv_scope + delim + 'conv_quant/' + quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
else:
expected_inputs = [
- scope + '/conv_quant/AssignMinEma',
- scope + '/conv_quant/AssignMaxEma',
+ conv_scope + delim + 'conv_quant/AssignMinEma',
+ conv_scope + delim + 'conv_quant/AssignMaxEma',
]
- expected_inputs.append(scope + '/add_fold')
+ expected_inputs.append(conv_scope + delim + 'add_fold')
self._AssertInputOpsAre(conv_quant, expected_inputs)
output_op_name = (
- scope + '/conv_quant/delayed_quant/Switch_1' if delay else 'test/Add')
+ conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
+ if delay else scope + 'Add')
self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])
- act_quant = graph.get_operation_by_name(
- 'test/act_quant/' + quantization_node_name)
+ act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
+ quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
else:
expected_inputs = [
- 'test/act_quant/AssignMinEma',
- 'test/act_quant/AssignMaxEma',
+ scope + 'act_quant/AssignMinEma',
+ scope + 'act_quant/AssignMaxEma',
]
- expected_inputs.append('test/' + activation_op_name)
+ expected_inputs.append(scope + activation_op_name)
self._AssertInputOpsAre(act_quant, expected_inputs)
- output_op_name = ('test/act_quant/delayed_quant/Switch_1'
- if delay else 'control_dependency')
+ output_op_name = (
+ scope + 'act_quant/delayed_quant/Switch_1'
+ if delay else 'control_dependency')
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
self._AssertIdempotent(graph)
@@ -433,7 +488,7 @@
def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, fused_batch_norm,
- use_resource):
+ use_resource, scope):
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation.
Args:
@@ -445,6 +500,7 @@
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -453,7 +509,9 @@
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
out_depth = 3 if with_bypass else 32
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = conv2d(
inputs,
out_depth, [5, 5],
@@ -463,13 +521,13 @@
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -487,7 +545,7 @@
def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, fused_batch_norm,
- use_resource):
+ use_resource, scope):
"""Tests quantization: inputs -> FC with batch norm -> Activation.
Args:
@@ -499,6 +557,7 @@
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -506,7 +565,9 @@
batch_size, depth = 5, 256
inputs = array_ops.zeros((batch_size, depth))
out_depth = 256 if with_bypass else 128
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = fully_connected(
inputs,
out_depth,
@@ -514,13 +575,13 @@
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -540,7 +601,7 @@
def _TestQuantize_DepthwiseConv2dWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay,
- fused_batch_norm, use_resource):
+ fused_batch_norm, use_resource, scope):
"""Tests quantization: inputs -> DWConv2d with batch norm -> Activation.
Args:
@@ -552,6 +613,7 @@
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -559,7 +621,9 @@
batch_size, height, width, depth = 5, 128, 128, 3
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = separable_conv2d(
inputs,
None, [5, 5],
@@ -570,13 +634,13 @@
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -595,7 +659,7 @@
def _TestQuantize_AtrousConvWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay,
- fused_batch_norm, use_resource):
+ fused_batch_norm, use_resource, scope):
"""Tests quantization: inputs -> atrous conv with batch norm -> Activation.
Args:
@@ -607,6 +671,7 @@
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -614,7 +679,10 @@
batch_size, height, width, depth = 5, 128, 128, 3
inputs = array_ops.zeros((batch_size, height, width, depth))
dilation_rate = 2
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
+
node = separable_conv2d(
inputs,
None, [3, 3],
@@ -625,13 +693,13 @@
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -718,6 +786,18 @@
with open('/tmp/bn_quant_test.pbtxt', 'w') as f:
f.write(str(graph.as_graph_def()))
+ def _GetConvScope(self, scope, with_bypass):
+ if scope is None:
+ scope = ''
+ delim = '/' if scope else ''
+
+ if with_bypass:
+ conv_scope = scope + delim + 'test2'
+ else:
+ conv_scope = scope
+
+ return conv_scope
+
def _BatchNormParams(self, fused=False, force_updates=False):
params = {
'center': True,
diff --git a/tensorflow/contrib/rate/rate_test.py b/tensorflow/contrib/rate/rate_test.py
index 0890810..3dee163 100644
--- a/tensorflow/contrib/rate/rate_test.py
+++ b/tensorflow/contrib/rate/rate_test.py
@@ -46,7 +46,7 @@
@test_util.run_in_graph_and_eager_modes()
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
r_ = rate.Rate()
a = r_(array_ops.ones([1]), denominator=array_ops.ones([1]))
self.evaluate(variables.global_variables_initializer())
@@ -67,7 +67,7 @@
@test_util.run_in_graph_and_eager_modes()
def testWhileLoop(self):
- with self.test_session():
+ with self.cached_session():
r_ = rate.Rate()
def body(value, denom, i, ret_rate):
diff --git a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
index 6253f96..e30e725 100644
--- a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
+++ b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
@@ -210,7 +210,7 @@
# Input data shape is not defined over a 2D grid, i.e. its shape is not like
# (batch_size, data_height, data_width, data_channels).
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data_shape = (batch_size, data_height, data_width, data_depth,
data_channels)
data = np.zeros(data_shape)
@@ -225,7 +225,7 @@
sess.run(outputs)
# Warp tensor must be at least a matrix, with shape [batch_size, 2].
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data_shape = (batch_size, data_height, data_width, data_channels)
data = np.zeros(data_shape)
warp_shape = (batch_size,)
@@ -238,7 +238,7 @@
sess.run(outputs)
# The batch size of the data and warp tensors must be the same.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data_shape = (batch_size, data_height, data_width, data_channels)
data = np.zeros(data_shape)
warp_shape = (batch_size+1, warp_height, warp_width, 2)
@@ -252,7 +252,7 @@
# The warp tensor must contain 2D coordinates, i.e. its shape last dimension
# must be 2.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data_shape = (batch_size, data_height, data_width, data_channels)
data = np.zeros(data_shape)
warp_shape = (batch_size, warp_height, warp_width, 3)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
index 1c23c28..0d61592 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
@@ -49,7 +49,7 @@
return rpc_op.try_rpc(*args, protocol=self._protocol, **kwargs)
def testScalarHostPortRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = (
test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors = self.rpc(
@@ -63,7 +63,7 @@
self.assertAllEqual([2, 3, 4], response_message.values)
def testScalarHostPortTryRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = (
test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors, status_code, status_message = self.try_rpc(
@@ -83,7 +83,7 @@
self.assertEqual(b'', status_message_values)
def testEmptyHostPortRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = []
response_tensors = self.rpc(
method=self.get_method_name('Increment'),
@@ -98,7 +98,7 @@
'/InvalidService.Increment',
self.get_method_name('InvalidMethodName')
]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError(self.invalid_method_string):
sess.run(self.rpc(method=method, address=self._address, request=''))
@@ -111,7 +111,7 @@
def testInvalidAddress(self):
# This covers the case of address='' and address='localhost:293874293874'
address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.UnavailableError):
sess.run(
self.rpc(
@@ -128,7 +128,7 @@
self.connect_failed_string in status_message_value.decode('ascii'))
def testAlwaysFailingMethod(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
response_tensors = self.rpc(
method=self.get_method_name('AlwaysFailWithInvalidArgument'),
address=self._address,
@@ -150,7 +150,7 @@
self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii'))
def testSometimesFailingMethodWithManyRequests(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Fail hard by default.
response_tensors = self.rpc(
method=self.get_method_name('SometimesFailWithInvalidArgument'),
@@ -179,7 +179,7 @@
self.assertAllEqual(expected_message_values, status_message_values)
def testVecHostPortRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
@@ -197,7 +197,7 @@
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortManyParallelRpcs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
@@ -219,7 +219,7 @@
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = encode_proto_op.encode_proto(
message_type='tensorflow.contrib.rpc.TestCase',
field_names=['values'],
@@ -241,7 +241,7 @@
for i in range(20)], response_shape_values)
def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [''] * 25 # This will launch 25 RPC requests.
response_tensors = self.rpc(
method=self.get_method_name('SleepForever'),
@@ -254,7 +254,7 @@
sess.run(response_tensors, options=options)
def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [''] * 25 # This will launch 25 RPC requests.
response_tensors = self.rpc(
method=self.get_method_name('SleepForever'),
@@ -265,7 +265,7 @@
sess.run(response_tensors)
def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
response_tensors, status_code, status_message = self.try_rpc(
method=self.get_method_name('SometimesSleepForever'),
timeout_in_ms=1000,
@@ -281,7 +281,7 @@
def testTryRpcWithMultipleAddressesSingleRequest(self):
flatten = lambda x: list(itertools.chain.from_iterable(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
addresses = flatten([[
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
] for _ in range(10)])
@@ -301,7 +301,7 @@
def testTryRpcWithMultipleMethodsSingleRequest(self):
flatten = lambda x: list(itertools.chain.from_iterable(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
methods = flatten(
[[self.get_method_name('Increment'), 'InvalidMethodName']
for _ in range(10)])
@@ -319,7 +319,7 @@
def testTryRpcWithMultipleAddressesAndRequests(self):
flatten = lambda x: list(itertools.chain.from_iterable(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
addresses = flatten([[
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
] for _ in range(10)])
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc
index 4fc36d8..c669ced 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim.cc
@@ -355,11 +355,15 @@
const SessionOptions& session_options, const RunOptions& run_options,
const string& export_dir,
const std::unordered_set<string>& saved_model_tags,
- SavedModelBundle* saved_model_bundle) {
+ SavedModelBundle* saved_model_bundle, bool* is_session_bundle) {
+ if (is_session_bundle != nullptr) {
+ *is_session_bundle = false;
+ }
if (MaybeSavedModelDirectory(export_dir)) {
LOG(INFO)
<< "Attempting to load native SavedModelBundle in bundle-shim from: "
<< export_dir;
+
return LoadSavedModel(session_options, run_options, export_dir,
saved_model_tags, saved_model_bundle);
} else if (IsPossibleExportDirectory(export_dir)) {
@@ -368,6 +372,9 @@
LOG(INFO) << "Attempting to up-convert SessionBundle to SavedModelBundle "
"in bundle-shim from: "
<< export_dir;
+ if (is_session_bundle != nullptr) {
+ *is_session_bundle = true;
+ }
return LoadSavedModelFromLegacySessionBundlePath(
session_options, run_options, export_dir, saved_model_bundle);
}
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.h b/tensorflow/contrib/session_bundle/bundle_shim.h
index 4628b6a..7f0f995 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.h
+++ b/tensorflow/contrib/session_bundle/bundle_shim.h
@@ -59,11 +59,13 @@
} // namespace internal
// Loads a SavedModel from either a session-bundle path or a SavedModel bundle
-// path.
+// path. If `is_session_bundle` is not a nullptr, sets it to `true` iff
+// SavedModel was up-converted and loaded from a SessionBundle.
+// `is_session_bundle` value should not be used if error is returned.
Status LoadSessionBundleOrSavedModelBundle(
const SessionOptions& session_options, const RunOptions& run_options,
const string& export_dir, const std::unordered_set<string>& tags,
- SavedModelBundle* bundle);
+ SavedModelBundle* bundle, bool* is_session_bundle = nullptr);
} // namespace serving
} // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
index 9a1dd93..815beb7 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
@@ -63,12 +63,16 @@
void LoadAndValidateSavedModelBundle(const string& export_dir,
const std::unordered_set<string>& tags,
- const string& signature_def_key) {
+ const string& signature_def_key,
+ bool expect_session_bundle) {
SessionOptions session_options;
RunOptions run_options;
SavedModelBundle saved_model_bundle;
+ bool is_session_bundle = false;
TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(
- session_options, run_options, export_dir, tags, &saved_model_bundle));
+ session_options, run_options, export_dir, tags, &saved_model_bundle,
+ &is_session_bundle));
+ EXPECT_EQ(expect_session_bundle, is_session_bundle);
const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
const auto& signature_def_map = meta_graph_def.signature_def();
@@ -512,7 +516,8 @@
const string session_bundle_export_dir =
test_util::TestSrcDirPath(kSessionBundlePath);
LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
- kDefaultServingSignatureDefKey);
+ kDefaultServingSignatureDefKey,
+ /*expect_session_bundle=*/true);
// Verify that the named signature is also present.
SessionOptions session_options;
@@ -558,7 +563,8 @@
const string saved_model_bundle_export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath);
LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir,
- {kSavedModelTagServe}, "regress_x_to_y");
+ {kSavedModelTagServe}, "regress_x_to_y",
+ /*expect_session_bundle=*/false);
}
// Checks a basic load fails with an invalid export path.
diff --git a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
index e4db5f2..e6a0b30 100644
--- a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
+++ b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
@@ -38,7 +38,7 @@
graph_def = graph.as_graph_def()
ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
for _ in range(20):
diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py
index ae8336d..807741e 100644
--- a/tensorflow/contrib/summary/summary_ops_graph_test.py
+++ b/tensorflow/contrib/summary/summary_ops_graph_test.py
@@ -52,7 +52,7 @@
summary_ops.histogram('histogram', [1.0], step=1)
summary_ops.image('image', [[[[1.0]]]], step=1)
summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
sess.run(summary_ops.all_summary_ops())
# The working condition of the ops is tested in the C++ test so we just
@@ -64,7 +64,7 @@
writer = summary_ops.create_file_writer(logdir, max_queue=0)
with writer.as_default(), summary_ops.always_record_summaries():
summary_ops.scalar('scalar', 2.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
sess.run(summary_ops.all_summary_ops())
events = summary_test_util.events_from_logdir(logdir)
@@ -77,7 +77,7 @@
with writer.as_default(), summary_ops.always_record_summaries():
with ops.name_scope('scope'):
summary_ops.scalar('scalar', 2.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
sess.run(summary_ops.all_summary_ops())
events = summary_test_util.events_from_logdir(logdir)
@@ -90,7 +90,7 @@
writer = summary_ops.create_file_writer(logdir, max_queue=0)
with writer.as_default(), summary_ops.always_record_summaries():
summary_ops.scalar('scalar', 2.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(summary_ops.summary_writer_initializer_op())
step, _ = sess.run(
@@ -105,7 +105,7 @@
logdir, max_queue=1, flush_millis=999999)
with writer.as_default(), summary_ops.always_record_summaries():
summary_ops.scalar('scalar', 2.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
# Note: First tf.Event is always file_version.
@@ -123,7 +123,7 @@
with writer.as_default(), summary_ops.always_record_summaries():
summary_ops.scalar('scalar', 2.0, step=1)
flush_op = summary_ops.flush()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
# Note: First tf.Event is always file_version.
@@ -157,7 +157,7 @@
with writer3.as_default():
summary_ops.scalar('three', 3.0, step=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Run init ops across writers sequentially to avoid race condition.
# TODO(nickfelt): fix race condition in resource manager lookup or create
sess.run(writer1.init())
@@ -191,7 +191,7 @@
logdir, max_queue=100, flush_millis=1000000)
with writer.as_default():
summary_ops.scalar('one', 1.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
self.assertEqual(1, get_total()) # file_version Event
@@ -219,7 +219,7 @@
logdir, max_queue=100, flush_millis=1000000)
with writer.as_default():
summary_ops.scalar('one', 1.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
self.assertEqual(1, get_total()) # file_version Event
@@ -241,7 +241,7 @@
training_util.get_or_create_global_step()
name = 'hi'
graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),))
- with self.test_session():
+ with self.cached_session():
with self.create_db_writer().as_default():
summary_ops.initialize(graph=graph)
six.assertCountEqual(self, [name],
@@ -249,7 +249,7 @@
def testScalarSummary(self):
"""Test record_summaries_every_n_global_steps and all_summaries()."""
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
global_step = training_util.get_or_create_global_step()
global_step.initializer.run()
with ops.device('/cpu:0'):
@@ -280,7 +280,7 @@
def testScalarSummaryNameScope(self):
"""Test record_summaries_every_n_global_steps and all_summaries()."""
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
global_step = training_util.get_or_create_global_step()
global_step.initializer.run()
with ops.device('/cpu:0'):
@@ -311,7 +311,7 @@
self.assertEqual(events[1].summary.value[0].tag, 'scope/my_scalar')
def testSummaryGraphModeCond(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
with summary_ops.create_file_writer(
@@ -332,7 +332,7 @@
self.assertEqual(events[1].summary.value[0].tag, 'cond/scalar')
def testSummaryGraphModeWhile(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
with summary_ops.create_file_writer(
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
index e429d12..1c4e18d 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
@@ -32,7 +32,7 @@
indices = [[1], [10]]
updates = [100., 200.]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual(
@@ -45,7 +45,7 @@
indices = [[0, 0, 1], [1, 1, 2]]
updates = [100., 200.]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual([[[1., 102., 3.], [4., 5., 6.]],
@@ -57,7 +57,7 @@
indices = []
updates = []
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual(init_val, input_data.eval())
@@ -67,7 +67,7 @@
input_data = variables.Variable(init_val)
indices = [[0, 0, 1], [1, 1, 2]]
updates = [100.]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
with self.assertRaisesOpError(
'Number of updates should be same as number of indices.'):
@@ -80,7 +80,7 @@
indices = [[0, 0], [1, 1]]
updates = [[100., 200., 300.], [400., 500., 600.]]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual([[[101., 202., 303.], [4., 5., 6.]],
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index 1c9c818..e0f0c0d 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -149,7 +149,7 @@
self.assertTrue(isinstance(probs, ops.Tensor))
self.assertTrue(isinstance(paths, ops.Tensor))
self.assertTrue(isinstance(var, ops.Tensor))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
resources.initialize_resources(resources.shared_resources()).run()
self.assertEquals(probs.eval().shape, (4, 2))
diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md
index 687dee0..caf8b6d 100644
--- a/tensorflow/contrib/tensorrt/README.md
+++ b/tensorflow/contrib/tensorrt/README.md
@@ -26,4 +26,4 @@
In order to make use of TensorRT integration, you will need a local installation
of TensorRT 3.0.4 from the [NVIDIA Developer website](https://developer.nvidia.com/tensorrt).
Installation instructions for compatibility with TensorFlow are provided on the
-[TensorFlow Installation page](https://www.tensorflow.org/install/install_linux#nvidia_requirements_to_run_tensorflow_with_gpu_support).
+[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide.
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index fe6f8b4..7ad9bf2 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -780,12 +780,12 @@
// If device is not set, use the first found GPU device for the conversion.
for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
TfGpuId tf_gpu_id(tf_gpu_id_value);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (s.ok()) {
VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
- << cuda_gpu_id.value();
- cuda_device_id = cuda_gpu_id.value();
+ << platform_gpu_id.value();
+ cuda_device_id = platform_gpu_id.value();
GPUOptions gpu_options;
// If the TF to Cuda gpu id mapping exist, the device and corresponding
// allocator must have been initialized already, so the
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 2b42d81..88cf8d5 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -565,21 +565,22 @@
new TRTInt8Calibrator(device_buffers_, batch_size, name()));
const string label(name());
auto segment_graph = &segment_graph_;
- const int cuda_gpu_id = ctx->device()->tensorflow_gpu_device_info()->gpu_id;
- if (cuda_gpu_id < 0) {
+ const int platform_gpu_id =
+ ctx->device()->tensorflow_gpu_device_info()->gpu_id;
+ if (platform_gpu_id < 0) {
LOG(ERROR) << "Can't get gpu_device_info from context->device()";
return tensorflow::errors::InvalidArgument(
"Context->device doesn't contain device info!");
}
const int64 workspace_size_bytes = workspace_size_;
cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes,
- cuda_gpu_id, workspace_size_bytes]() {
- VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id
+ platform_gpu_id, workspace_size_bytes]() {
+ VLOG(0) << "Starting calibration thread on device " << platform_gpu_id
<< ", Calibration Resource @ " << cres;
- auto err = cudaSetDevice(cuda_gpu_id);
+ auto err = cudaSetDevice(platform_gpu_id);
if (err != cudaSuccess) {
// TODO(aaroey): should return error here.
- LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id
+ LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
<< " in calibration thread";
}
// ConvertGraphDefToEngine() will try to build the engine. This thread
diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
index 84e3614..832d34d 100644
--- a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
+++ b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
@@ -63,7 +63,7 @@
(b"jumps", b"brown"),
(b"jumps", b"fox"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -94,7 +94,7 @@
(b"jumps", b"fox"),
(b"jumps", b"jumps"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -105,7 +105,7 @@
# If emit_self_as_target is False (default), output will be empty.
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, tokens.eval().size)
self.assertEqual(0, labels.eval().size)
@@ -117,7 +117,7 @@
(b"quick", b"quick"),
(b"brown", b"brown"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -134,7 +134,7 @@
(b"brown", b"the"),
(b"brown", b"quick"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -150,7 +150,7 @@
(b"quick", b"brown"),
(b"brown", b"quick"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -165,7 +165,7 @@
(b"quick", b"brown"),
(b"brown", b"quick"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -196,7 +196,7 @@
(b"over", b"fox"),
(b"over", b"jumps"),
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens_eval, labels_eval = sess.run([tokens, labels])
self.assertAllEqual(expected_tokens, tokens_eval)
self.assertAllEqual(expected_labels, labels_eval)
@@ -222,7 +222,7 @@
tokens_2, labels_2 = text.skip_gram_sample(
input_tensor, min_skips=1, max_skips=5)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run(
[tokens_1, labels_1, tokens_2, labels_2])
@@ -244,7 +244,7 @@
(b"brown", b"fox"),
(b"fox", b"brown"),
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
@@ -269,7 +269,7 @@
(2, 3),
(3, 2),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -286,7 +286,7 @@
for min_skips, max_skips in invalid_skips:
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=min_skips, max_skips=max_skips)
- with self.test_session() as sess, self.assertRaises(
+ with self.cached_session() as sess, self.assertRaises(
errors.InvalidArgumentError):
sess.run([tokens, labels])
@@ -338,7 +338,7 @@
vocab_freq_table = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), -1)
- with self.test_session():
+ with self.cached_session():
vocab_freq_table.init.run()
# No vocab_freq_table specified - output should be the same as input.
@@ -395,7 +395,7 @@
vocab_freq_table = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), -1)
- with self.test_session():
+ with self.cached_session():
vocab_freq_table.init.run()
output = skip_gram_ops._filter_input(
input_tensor=input_tensor,
@@ -464,7 +464,7 @@
(b"life", b"and"),
(b"and", b"life"),
])
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -510,7 +510,7 @@
(b"to", b"life"),
(b"life", b"to"),
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
lookup_ops.tables_initializer().run()
tokens_eval, labels_eval = sess.run([tokens, labels])
self.assertAllEqual(expected_tokens, tokens_eval)
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 298ffc1..4e0b612 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -80,7 +80,7 @@
"tpu_embedding_ops",
],
deps = [
- "//tensorflow/contrib/tpu/proto:tpu_embedding_config_proto_cc",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
],
@@ -99,7 +99,7 @@
"ops/tpu_embedding_ops.cc",
],
deps = [
- "//tensorflow/contrib/tpu/proto:tpu_embedding_config_proto_cc",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
"//tensorflow/core:lib_proto_parsing",
],
)
@@ -351,7 +351,7 @@
tf_py_test(
name = "topology_test",
- size = "small",
+ size = "medium",
srcs = ["python/tpu/topology_test.py"],
additional_deps = [
":tpu",
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 72d37f7..18b9893 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -13,7 +13,7 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/tpu/proto/tpu_embedding_config.pb.h"
+#include "tensorflow/contrib/tpu/proto/tpu_embedding_configuration.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -88,12 +88,12 @@
int table_id;
TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
- int64 num_tables = config.table_config_size();
+ int64 num_tables = config.table_descriptor_size();
if (table_id >= num_tables) {
return errors::InvalidArgument("Table id >= num_tables");
}
- int64 width = config.table_config(table_id).width();
- int64 num_rows = config.table_config(table_id).num_rows();
+ int64 width = config.table_descriptor(table_id).dimension();
+ int64 num_rows = config.table_descriptor(table_id).vocabulary_size();
TF_RETURN_IF_ERROR(c->set_output("parameters", {c->Matrix(num_rows, width)}));
return Status::OK();
@@ -160,12 +160,12 @@
int table_id;
TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
- int64 num_tables = config.table_config_size();
+ int64 num_tables = config.table_descriptor_size();
if (table_id >= num_tables) {
return errors::InvalidArgument("Table id >= num_tables");
}
- int64 width = config.table_config(table_id).width();
- int64 num_rows = config.table_config(table_id).num_rows();
+ int64 width = config.table_descriptor(table_id).dimension();
+ int64 num_rows = config.table_descriptor(table_id).vocabulary_size();
TF_RETURN_IF_ERROR(c->set_output("parameters", {c->Matrix(num_rows, width)}));
TF_RETURN_IF_ERROR(
@@ -244,11 +244,12 @@
if (!config.ParseFromString(config_string)) {
return errors::InvalidArgument("Malformed tpu_embedding_config.");
}
- int64 batch_size = config.batch_size();
- int64 num_tables = config.table_config_size();
+ int64 batch_size = config.batch_size_per_tensor_core();
+ int64 num_tables = config.table_descriptor_size();
for (int table_id = 0; table_id < num_tables; ++table_id) {
- int64 width = config.table_config(table_id).width();
- int64 num_features = config.table_config(table_id).num_features();
+ int64 width = config.table_descriptor(table_id).dimension();
+ // Activation rows scale with features per sample, not the vocabulary size.
+ int64 num_features = config.table_descriptor(table_id).num_features();
c->set_output(table_id, c->Matrix(batch_size * num_features, width));
}
return Status::OK();
diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD
index 598b73b..c20cab8 100644
--- a/tensorflow/contrib/tpu/proto/BUILD
+++ b/tensorflow/contrib/tpu/proto/BUILD
@@ -10,12 +10,15 @@
)
tf_proto_library(
- name = "tpu_embedding_config_proto",
+ name = "tpu_embedding_configuration_proto",
srcs = [
- "tpu_embedding_config.proto",
+ "tpu_embedding_configuration.proto",
],
cc_api_version = 2,
- protodeps = [":optimization_parameters_proto"],
+ protodeps = [
+ ":tpu_embedding_output_layout_proto",
+ ":optimization_parameters_proto",
+ ],
visibility = ["//visibility:public"],
)
@@ -29,6 +32,15 @@
)
tf_proto_library(
+ name = "tpu_embedding_output_layout_proto",
+ srcs = [
+ "tpu_embedding_output_layout.proto",
+ ],
+ cc_api_version = 2,
+ visibility = ["//visibility:public"],
+)
+
+tf_proto_library(
name = "topology_proto",
srcs = [
"topology.proto",
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
deleted file mode 100644
index 3476cc8..0000000
--- a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
+++ /dev/null
@@ -1,66 +0,0 @@
-syntax = "proto3";
-
-package tensorflow.tpu;
-
-import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
-
-// The TPUEmbeddingConfiguration contains specification of TPU Embedding lookups
-// and gradient updates separate from the TF Graph.
-message TPUEmbeddingConfiguration {
- // model_mode specifies whether the model is to be run in training or
- // inference. In inference mode, gradient updates to embedding tables are not
- // performed.
- enum ModelMode {
- INVALID = 0;
- TRAINING = 1;
- INFERENCE = 2;
- }
-
- ModelMode model_mode = 1;
-
- // num_hosts is the number of host CPU systems in the training/inference job.
- // Each embedding table must be sharded into num_hosts separate Variables,
- // placed separately on the num_hosts CPU devices in the cluster. Sharding
- // will be performed equivalently to the 'div' sharding_strategy option of
- // embedding_lookup() and embedding_lookup_sparse().
- int32 num_hosts = 2;
-
- // The total number of TensorNodes. This is equal to num_hosts times the
- // number of TensorNodes attached to each host.
- int32 num_tensornodes = 3;
-
- // The number of training examples per TensorNode.
- int32 batch_size = 4;
-
- // Each Embedding
- message TPUEmbeddingTable {
- // Name of the embedding table. This will be used to name Variables in the
- // Tensorflow Graph.
- string name = 1;
-
- // Number of rows of the embedding table. The Variable created to hold the
- // learned embedding table values will have shape (num_rows, width).
- int32 num_rows = 3;
-
- // Width of the embedding table. The Variable created to hold the
- // learned embedding table values will have shape (num_rows, width).
- int32 width = 4;
-
- // Number of distinct embedding activation vectors per training example
- // produced by lookups into this table during model evaluation. For each
- // table, the Graph will receive an activations Tensor of shape
- // (batch_size * table.num_features, table.width).
- // For example, num_features = 1 produces equivalent behavior to a single
- // tf.nn.embedding_lookup() call. In the case of 'multivalent' embeddings,
- // (i.e. tf.nn.embedding_lookup_sparse()) which compute weighted averages of
- // embedding table rows, num_features is the number of vectors produced
- // after averaging. In sequence models num_features is typically equal
- // to the sequence length, since each sequence element must be represented
- // separately to the convolutional or recurrent network.
- int32 num_features = 5;
-
- OptimizationParameters optimization_parameters = 6;
- }
-
- repeated TPUEmbeddingTable table_config = 5;
-}
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto
new file mode 100644
index 0000000..da19b13
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto
@@ -0,0 +1,95 @@
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
+import "tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto";
+
+message TPUEmbeddingConfiguration {
+ // Description of the various embedding tables.
+ message TableDescriptor {
+ // Name of the table.
+ string name = 1;
+ // Size of the vocabulary (i.e., number of rows) in the table.
+ int32 vocabulary_size = 2;
+ // The embedding dimension (i.e., the width of the embedding table).
+ int32 dimension = 3;
+ // Number of features mapped to this table.
+ int32 num_features = 4;
+ // Details of the learning algorithm used to update the embedding
+ // parameters.
+ OptimizationParameters optimization_parameters = 5;
+ }
+ repeated TableDescriptor table_descriptor = 1;
+
+ // Mode. Should the embedding layer program be run for inference (just forward
+ // pass), training (both forward and backward pass), or just the backward pass.
+ enum Mode {
+ UNSPECIFIED = 0;
+ INFERENCE = 1;
+ TRAINING = 2;
+ BACKWARD_PASS_ONLY = 3;
+ }
+ Mode mode = 2;
+
+ // Number of samples in each batch of embedding layer activations sent to
+ // the TensorCore.
+ int32 batch_size_per_tensor_core = 3;
+
+ // Number of TPU hosts used for inference/training.
+ int32 num_hosts = 4;
+
+ // Number of TensorCores used for inference/training.
+ int32 num_tensor_cores = 5;
+
+ // Sharding strategy of the embedding tables among the hosts.
+ // If the sharding_strategy is "mod", each id is assigned to host
+ // "id % num_hosts". For instance, 13 ids are split across 5 hosts as:
+ // [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]].
+ // If the sharding_strategy is "div", ids are assigned to hosts in a
+ // contiguous manner. In this case, 13 ids are split across 5 hosts as:
+ // [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
+ // In both strategies, if the id space does not evenly divide the number
+ // of hosts, each of the first "table_descriptor.vocabulary_size % num_hosts"
+ // hosts will be assigned one more id.
+ // This partitioning strategy exactly follows that in the embedding_lookup
+ // TensorFlow function at tensorflow/python/ops/embedding_ops.py.
+ enum ShardingStrategy {
+ DIV_DEFAULT = 0;
+ MOD = 1;
+ }
+ ShardingStrategy sharding_strategy = 6;
+
+ // This parameter determines if the execution of the sparse core will be
+ // pipelined with that of the TensorCore. This parameter only affects results
+ // when mode=TRAINING. If mode=INFERENCE or BACKWARD_PASS_ONLY, this parameter
+ // does not affect execution and hence, is a don't care value.
+ //
+ // false: The execution of the sparse core is not pipelined with that of the
+ // TensorCore. The forward pass of every step on the sparse core is executed
+ // only after the backward pass of the previous step is complete. And the
+ // backward pass on the sparse core is executed only after the embedding
+ // gradients have been computed on the TensorCore on every step. This ensures
+ // that the activations on every step observe the gradient updates from the
+ // previous step on both the sparse core and the TensorCore.
+ //
+ // true: The execution of the sparse core is pipelined with that of the
+ // TensorCore. The forward pass of every step on the sparse core can be
+ // executed after the forward pass of the previous step is complete without
+ // waiting for the backward pass. This improves the utilization of the sparse
+ // core allowing it to process step N+1 while the embedding gradients for step
+ // N are computed on the TensorCore. The backward pass of every step on the
+ // sparse core is executed directly after the forward pass for the next step
+ // is complete. The drawback is that embedding activations for step N+1 do not
+ // observe the embedding gradient updates from step N. This could affect model
+ // quality if step N and N+1 involve the same set of embedding IDs. However,
+ // since the embedding updates are sparse, this is generally not considered a
+ // problem.
+ bool pipeline_execution_with_tensor_core = 7;
+
+ // Extended output layout information; if not provided, a compatibility mode
+ // will use defaults that match the old layout. Providing a value for this
+ // field is EXPERIMENTAL and most ways of filling it will probably break. Do
+ // not set it unless you know what you are doing.
+ TPUEmbeddingOutputLayout output_layout = 8;
+}
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto
new file mode 100644
index 0000000..aed30b2
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto
@@ -0,0 +1,75 @@
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+// In the comments here, "layout" refers to the top-level EmbeddingOutputLayout
+// proto contained in the TPUEmbeddingConfiguration.
+
+// The embedding output consists of a list of tensors, each specified by an
+// EmbeddingOutputTensor proto within the EmbeddingOutputLayout (the "output"
+// field). Each table and feature lookup is then placed into some number of
+// particular positions within some output tensor (identified by "tensor_index"
+// within OutputLocation). The tree of table lookups, feature lookups, and
+// output locations is specified by the
+// "table(table_id).feature(feature_id).output_location" repeated fields within
+// EmbeddingOutputLayout.
+
+message TPUEmbeddingOutputLayout {
+ // Location of one copy of the feature's data.
+ message OutputLocation {
+ // Which output tensor this copy of the feature will go into. Must be
+ // between 0 and layout.output_size().
+ int32 tensor_index = 1;
+
+ // Offset in dimension 0 for this feature copy. Must be between 0 and
+ // layout.output(tensor_index).dim0_size_per_sample().
+ int32 dim0_offset = 2;
+
+ // Offset in dimension 1 for this feature copy. Must be between 0 and
+ // layout.output(tensor_index).dim1_size() - table width; repeated or
+ // partially/fully overlapping values are allowed, and results that land in
+ // the same range will be summed (with the gradients replicated in the
+ // backward pass).
+ int32 dim1_offset = 3;
+ }
+
+ // Description of the output placement for one feature.
+ message FeatureDescriptor {
+ // Typically, only one copy of each feature is used, but multiple are
+ // allowed and the same data will be copied to all of them (with the
+ // gradients summed in the backward pass).
+ repeated OutputLocation output_location = 1;
+ }
+
+ // Description of the output placement for features of one table.
+ message TableDescriptor {
+ // Output locations for each feature loaded from this table.
+ repeated FeatureDescriptor feature = 1;
+ }
+ // Output locations for each feature of each table.
+ repeated TableDescriptor table = 1;
+
+ // Data layout and shape computation information for a single output tensor.
+ // Any unused locations in the tensor will be filled with zeros, and
+ // corresponding gradients will be ignored.
+
+ // Size and layout information for 2-D tensors.
+ message TwoDOutputTensor {
+ // Multiplier for output dimension 0 size; used to match legacy format that
+ // stacks features within a sample in dimension 0.
+ int32 dim0_size_per_sample = 2;
+
+ // The size (in dimension 1) of this output tensor.
+ int32 dim1_size = 1;
+ }
+
+ // Format information for a single output tensor.
+ message EmbeddingOutputTensor {
+ oneof output_format {
+ TwoDOutputTensor two_d = 4;
+ }
+ }
+
+ // Shape and layout information for each tensor.
+ repeated EmbeddingOutputTensor output = 2;
+}
diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py
index 471b1fa..b9e2a42 100644
--- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py
+++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py
@@ -72,13 +72,12 @@
self._invert_topology(topology))
topology_rank = self._topology_tasks.ndim
- if core_assignment.ndim != topology_rank + 2:
- raise ValueError("core_assignment must be a rank {} numpy array".format(
- topology_rank + 2))
+ if core_assignment.ndim != 3:
+ raise ValueError("core_assignment must be a rank 3 numpy array, "
+ "got shape {}".format(core_assignment.shape))
self._num_replicas = core_assignment.shape[0]
- self._computation_shape = np.array(
- core_assignment.shape[1:-1], dtype=np.int32)
+ self._num_cores_per_replica = core_assignment.shape[1]
if core_assignment.shape[-1] != topology_rank:
raise ValueError(
@@ -107,18 +106,15 @@
"""Computes a nested dict which maps task and logical core to replicas."""
task_and_cores_to_replicas = {}
for replica in xrange(core_assignment.shape[0]):
- for dx in xrange(core_assignment.shape[1]):
- for dy in xrange(core_assignment.shape[2]):
- for dz in xrange(core_assignment.shape[3]):
- x, y, z = core_assignment[replica, dx, dy, dz, :]
- task_id = topology_tasks[x, y, z]
- if task_id not in task_and_cores_to_replicas:
- task_and_cores_to_replicas[task_id] = {}
- logical_core = (dx, dy, dz)
- if logical_core not in task_and_cores_to_replicas[task_id]:
- task_and_cores_to_replicas[task_id][logical_core] = set()
+ for logical_core in xrange(core_assignment.shape[1]):
+ x, y, z = core_assignment[replica, logical_core, :]
+ task_id = topology_tasks[x, y, z]
+ if task_id not in task_and_cores_to_replicas:
+ task_and_cores_to_replicas[task_id] = {}
+ if logical_core not in task_and_cores_to_replicas[task_id]:
+ task_and_cores_to_replicas[task_id][logical_core] = set()
- task_and_cores_to_replicas[task_id][logical_core].add(replica)
+ task_and_cores_to_replicas[task_id][logical_core].add(replica)
task_to_sorted_replica_id = {}
@@ -136,23 +132,9 @@
return self._topology
@property
- def computation_shape(self):
- """The computation shape.
-
- Returns:
- A rank-1 int32 numpy array with size equal to the TPU topology rank.
- Describes the logical shape in numbers of core of each replica of the
- computation in the TPU topology.
-
- Returns:
- The computation shape.
- """
- return self._computation_shape
-
- @property
def num_cores_per_replica(self):
"""The number of cores per replica."""
- return np.prod(self.computation_shape)
+ return self._num_cores_per_replica
@property
def num_replicas(self):
@@ -164,33 +146,22 @@
"""The logical to physical core mapping.
Returns:
- A numpy array of rank `topology_rank + 2`, with shape
- `[num_replicas] + computation_shape + [topology_rank]`. Maps
- (replica, logical core coordinates) pairs to physical topology
- coordinates.
+ An integer numpy array of rank 3, with shape
+ `[num_replicas, num_cores_per_replica, topology_rank]`. Maps
+ (replica, logical core) pairs to physical topology coordinates.
"""
return self._core_assignment
def _coordinates(self, replica, logical_core):
"""Returns the physical topology coordinates of a logical core."""
- if logical_core is None:
- logical_core = np.array([0, 0, 0], np.int32)
- else:
- logical_core = np.asarray(logical_core)
-
- if any(logical_core < 0) or any(logical_core >= self.computation_shape):
- raise ValueError("Invalid core {}; computation shape is {}".format(
- logical_core, self.computation_shape))
-
- logical_offset = tuple([replica] + logical_core.tolist() + [slice(3)])
- return tuple(self.core_assignment[logical_offset])
+ return tuple(self.core_assignment[replica, logical_core, :])
def lookup_replicas(self, task_id, logical_core):
"""Lookup replica ids by task number and logical core.
Args:
task_id: TensorFlow task number.
- logical_core: A tuple of three integers which represents a logical core.
+ logical_core: An integer, identifying a logical core.
Returns:
A sorted list of the replicas that are attached to that task and
logical_core.
@@ -205,17 +176,17 @@
"Can not find any replica in task: {} contains logical_core: {} ".
format(task_id, logical_core))
- def tpu_ordinal(self, replica=0, logical_core=None):
+ def tpu_ordinal(self, replica=0, logical_core=0):
"""Returns the ordinal of the TPU device assigned to a logical core."""
coordinates = self._coordinates(replica, logical_core)
return self._topology_devices[coordinates]
- def host_device(self, replica=0, logical_core=None, job=None):
+ def host_device(self, replica=0, logical_core=0, job=None):
"""Returns the CPU device attached to a logical core."""
coordinates = self._coordinates(replica, logical_core)
return _tpu_host_device_name(job, self._topology_tasks[coordinates])
- def tpu_device(self, replica=0, logical_core=None, job=None):
+ def tpu_device(self, replica=0, logical_core=0, job=None):
"""Returns the name of the TPU device assigned to a logical core."""
coordinates = self._coordinates(replica, logical_core)
return _tpu_device_name(job, self._topology_tasks[coordinates],
@@ -228,6 +199,8 @@
num_replicas=1):
"""Computes a device_assignment of a computation across a TPU topology.
+ Attempts to choose a compact grid of cores for locality.
+
Returns a `DeviceAssignment` that describes the cores in the topology assigned
to each core of each replica.
@@ -240,12 +213,12 @@
`initialize_system` using `Session.run`. Either a serialized
`TopologyProto` or a `Topology` object may be passed. Note: you must
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
- computation_shape: A rank 1 int32 numpy array of size 3, describing the
- shape of the computation's block of cores. If None, the
- `computation_shape` is `[1, 1, 1]`.
- computation_stride: A rank 1 int32 numpy array of size 3, describing the
- inter-core spacing of the `computation_shape` cores in the TPU topology.
- If None, the `computation_stride` is `[1, 1, 1]`.
+ computation_shape: A rank 1 int32 numpy array with size equal to the
+ topology rank, describing the shape of the computation's block of cores.
+ If None, the `computation_shape` is `[1] * topology_rank`.
+ computation_stride: A rank 1 int32 numpy array of size `topology_rank`,
+ describing the inter-core spacing of the `computation_shape` cores in the
+ TPU topology. If None, the `computation_stride` is `[1] * topology_rank`.
num_replicas: The number of computation replicas to run. The replicas will
be packed into the free spaces of the topology.
@@ -271,21 +244,21 @@
topology_rank = len(topology.mesh_shape)
mesh_shape = topology.mesh_shape
if computation_shape is None:
- computation_shape = np.array([1, 1, 1], dtype=np.int32)
+ computation_shape = np.array([1] * topology_rank, dtype=np.int32)
else:
computation_shape = np.asarray(computation_shape, dtype=np.int32)
if computation_stride is None:
- computation_stride = np.array([1, 1, 1], dtype=np.int32)
+ computation_stride = np.array([1] * topology_rank, dtype=np.int32)
else:
computation_stride = np.asarray(computation_stride, dtype=np.int32)
- if computation_shape.shape != (3,):
- raise ValueError("computation_shape must have shape [3]; got {}".format(
- computation_shape.shape))
- if computation_stride.shape != (3,):
- raise ValueError("computation_stride must have shape [3]; got {}".format(
- computation_stride.shape))
+ if computation_shape.shape != (topology_rank,):
+ raise ValueError("computation_shape must have shape [{}]; got {}".format(
+ topology_rank, computation_shape.shape))
+ if computation_stride.shape != (topology_rank,):
+ raise ValueError("computation_stride must have shape [{}]; got {}".format(
+ topology_rank, computation_stride.shape))
if any(computation_shape < 1):
raise ValueError(
@@ -315,28 +288,41 @@
num_replicas, max_replicas, computation_shape, computation_stride,
mesh_shape))
- # Choose a compact layout for the cores. Choose the smaller dimension in the
- # topology to be close to the square root of the number of replicas.
- num_chips = int(math.ceil(num_replicas / replica_counts[2]))
- target_size = int(math.ceil(math.sqrt(num_chips)))
+ def ceil_of_ratio(n, m):
+ return (n + m - 1) // m
- # Prefer an even size, if possible. Odd numbered rows head back towards the
- # first column, so it's best if the last row has an odd index.
- if target_size % 2 != 0:
- target_size -= 1
- y_size = min(replica_counts[1], target_size)
- if y_size * replica_counts[0] < num_chips:
- y_size = replica_counts[1]
+ replica_shape = [0] * topology_rank
+ if num_replicas > 0:
+ remaining_replicas = num_replicas
+ remaining_dims = topology_rank
+
+ # Choose dimensions as close to an equal cube as possible, in order of
+ # increasing dimension size. By visiting dimensions in increasing size, we
+ # assign the most constrained dimension first, so we won't make infeasible
+ # choices.
+ #
+ # As a secondary sort order, visit the dimensions in reverse order. This
+ # means we try to use both cores on the same chip in preference to two cores
+ # on different chips.
+ for x, ni in sorted(((x, -i) for (i, x) in enumerate(replica_counts))):
+ i = -ni
+ target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims)))
+ replica_shape[i] = min(target_size, x)
+ remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i])
+ remaining_dims -= 1
+
+ assert remaining_replicas == 1 and remaining_dims == 0
# Assigns an offset to each replica such that no two replicas overlap.
- replica_offsets = np.full([num_replicas, 3], -1, dtype=np.int32)
+ replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32)
for replica in xrange(num_replicas):
- # Chooses a replica number in X/Y/Z axes.
- z = replica % replica_counts[2]
- t = replica // replica_counts[2]
- y = t % y_size
- x = t // y_size
- replica_pos = np.array([x, y, z], dtype=np.int32)
+ # Chooses a replica number in each axis.
+ t = replica
+ pos = []
+ for dim in replica_shape[::-1]:
+ pos.append(t % dim)
+ t //= dim
+ replica_pos = np.array(pos[::-1], dtype=np.int32)
# Determines where that replica starts in each axis.
outer = replica_pos // computation_stride
@@ -351,6 +337,6 @@
indices = np.concatenate(
[i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")],
axis=-1)
- assignment = (
- indices + replica_offsets[:, np.newaxis, np.newaxis, np.newaxis, :])
+ indices = indices.reshape((-1, topology_rank))
+ assignment = indices + replica_offsets[:, np.newaxis, :]
return DeviceAssignment(topology, core_assignment=assignment)
diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py
index 1fb26e7..ab89c6a 100644
--- a/tensorflow/contrib/tpu/python/tpu/topology.py
+++ b/tensorflow/contrib/tpu/python/tpu/topology.py
@@ -112,6 +112,11 @@
return self._mesh_shape
@property
+ def mesh_rank(self):
+ """Returns the number of dimensions in the mesh."""
+ return len(self._mesh_shape)
+
+ @property
def device_coordinates(self):
"""Describes the mapping from TPU devices to topology coordinates.
@@ -125,6 +130,16 @@
"""
return self._device_coordinates
+ @property
+ def num_tasks(self):
+ """Returns the number of TensorFlow tasks in the TPU slice."""
+ return self._device_coordinates.shape[0]
+
+ @property
+ def num_tpus_per_task(self):
+ """Returns the number of TPU devices per task in the TPU slice."""
+ return self._device_coordinates.shape[1]
+
def serialized(self):
"""Returns the serialized form of the topology."""
if self._serialized is None:
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 593f1d9..712b02f 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -76,7 +76,7 @@
"""Initializes a distributed TPU system for use with TensorFlow.
Args:
- embedding_config: If not None, an `EmbeddingLayerConfiguration` proto
+ embedding_config: If not None, a `TPUEmbeddingConfiguration` proto
describing the desired configuration of the hardware embedding lookup
tables. If embedding_config is None, no hardware embeddings can be used.
job: The job (the XXX in TensorFlow device specification /job:XXX) that
@@ -562,13 +562,14 @@
device_assignment.core_assignment.flatten().tolist()
}
# TODO(phawkins): remove this case after the forward compatibility window
- # expires on 2018-10-6.
- if api_compat.forward_compatible(2018, 10, 6):
+ # expires on 2018-10-5.
+ if api_compat.forward_compatible(2018, 10, 5):
metadata_kwargs["num_cores_per_replica"] = (
device_assignment.num_cores_per_replica)
else:
- metadata_kwargs["computation_shape"] = (
- device_assignment.computation_shape.tolist())
+ metadata_kwargs["computation_shape"] = [
+ device_assignment.num_cores_per_replica
+ ]
if ((not isinstance(inputs, list)) or
any(not isinstance(inp, (list, tuple)) for inp in inputs)):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 18e0abd..9f8d147 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -32,7 +32,6 @@
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
-_NUM_CORES_PER_HOST = 8
# pylint: enable=protected-access
@@ -103,7 +102,7 @@
input mode.
Raises:
- ValueError: If `num_cores_per_replica` is not 1, 2, 4 or 8.
+ ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16.
"""
def __new__(cls,
@@ -139,9 +138,9 @@
# Check num_cores_per_replica
if num_cores_per_replica is not None:
- if num_cores_per_replica not in [1, 2, 4, 8]:
+ if num_cores_per_replica not in [1, 2, 4, 8, 16]:
raise ValueError(
- 'num_cores_per_replica must be 1, 2, 4, or 8; got {}'.format(
+ 'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
str(num_cores_per_replica)))
# per_host_input_for_training may be True, False, or integer in [1..3].
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
index 2326fe9..b2fe0a6 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
@@ -86,7 +86,7 @@
def test_fail_with_invalid_num_cores_per_replica(self):
with self.assertRaisesRegexp(
- ValueError, 'num_cores_per_replica must be 1, 2, 4, or 8;'
+ ValueError, 'num_cores_per_replica must be 1, 2, 4, 8, or 16;'
' got 7'):
tpu_config_lib.TPUConfig(num_cores_per_replica=7)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 19359cb..b1a8a16 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -35,7 +35,8 @@
1: [1, 1, 1],
2: [1, 1, 2],
4: [1, 2, 2],
- 8: [2, 2, 2]
+ 8: [2, 2, 2],
+ 16: [4, 2, 2],
}
@@ -298,6 +299,7 @@
@property
def num_of_replicas_per_host(self):
+ """Return the number of replicas per host."""
if self.model_parallelism_enabled:
return self.num_replicas // self.num_hosts
else:
@@ -538,8 +540,8 @@
"""
if self.model_parallelism_enabled:
# We put both enqueue/dequeue ops at tpu.core(0) in each replica.
- replica = self.device_assignment.lookup_replicas(
- host_id, (0, 0, 0))[shard_index_in_host]
+ replica = self.device_assignment.lookup_replicas(host_id,
+ 0)[shard_index_in_host]
return self.device_assignment.tpu_ordinal(replica=replica)
else:
return shard_index_in_host % self.num_of_cores_per_host
@@ -580,6 +582,17 @@
raise ValueError(message)
+ if self._config.tpu_config.num_cores_per_replica:
+ num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
+ num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
+ if num_cores_per_replica > num_cores_per_host:
+ raise ValueError(
+ 'The num of cores required by the model parallelism, specified by '
+ 'TPUConfig.num_cores_per_replica, is larger than the '
+ 'num_cores_per_host. num_cores_per_replica: {}, '
+ 'num_cores_per_host: {}'.format(num_cores_per_replica,
+ num_cores_per_host))
+
if mode == model_fn_lib.ModeKeys.TRAIN:
if (self._train_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):
@@ -599,8 +612,8 @@
.format(self._eval_batch_size, num_replicas))
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
raise ValueError(
- 'TPUEstimator.evaluate should be running on single TPU worker. '
- 'got {}.'.format(num_hosts))
+ 'TPUEstimator.evaluate should be running on single TPU'
+ ' instead of a Pod.')
else:
assert mode == model_fn_lib.ModeKeys.PREDICT
if self._predict_batch_size is None:
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
index d9c77a3..e75a094 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
@@ -765,9 +765,8 @@
zip(per_host_sharded_inputs[replica_index], inputs_part_dims_flat)
]
- for core_index in xrange(self._device_assignment.num_cores_per_replica):
+ for logical_core in xrange(self._device_assignment.num_cores_per_replica):
# Places different partitions to different logic cores.
- logical_core = self._get_logical_core(core_index)
replica_id = self._device_assignment.lookup_replicas(
self._host_id, logical_core)[replica_index]
ordinal = self._device_assignment.tpu_ordinal(
@@ -784,7 +783,7 @@
inputs=infeed_inputs,
shapes=[x.shape for x in infeed_inputs],
name="enqueue/replica_{0}/input_{1}".format(
- replica_index, core_index),
+ replica_index, logical_core),
device_ordinal=ordinal))
return per_host_enqueue_ops
@@ -890,20 +889,3 @@
return nest.map_structure_up_to(
dequeues, self._tag_sharding_attribute_for_dequeued_tensor, dequeues,
dims)
-
- def _get_logical_core(self, core_index):
- """Maps the core index to the 3D coordinate within replica.
-
- The lowest dimension number in computation_shape is the slowest varying
- dimension (most major).
-
- Args:
- core_index: An integer represents the core index within replcia.
-
- Returns:
- A tuple with three integers which represents the 3D coordinate.
- """
- computation_shape = self._device_assignment.computation_shape
- return (core_index // (computation_shape[1] * computation_shape[2]),
- core_index % (computation_shape[1] * computation_shape[2]) //
- computation_shape[2], core_index % computation_shape[2])
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_function.py b/tensorflow/contrib/tpu/python/tpu/tpu_function.py
index de16e3b..0c7a38d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_function.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_function.py
@@ -63,10 +63,9 @@
"""Validate the number of input arguments to a tpu function.
Args:
- func: the Python function that will be called to generate the body
- of a TPUFunction.
- input_arity: the number of explicit arguments supplied by the
- caller.
+ func: the Python function that will be called to generate the body of an XLA
+ computation graph.
+ input_arity: the number of explicit arguments supplied by the caller.
infeed_queue: if not None, the infeed queue that will supply
additional arguments to the function.
@@ -103,4 +102,3 @@
# Since there are varargs, func can accept any number of arguments
# greater than the minimum.
return None
-
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 9bcf5b0..85b6d4f 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1067,7 +1067,6 @@
"spectral_ops",
"state_ops",
"stateless_random_ops",
- "string_ops",
"summary_ops",
"training_ops",
],
@@ -1075,6 +1074,13 @@
tf_gen_op_libs(
op_lib_names = [
+ "string_ops",
+ ],
+ deps = ["@com_google_absl//absl/strings"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
"array_ops",
],
deps = [":protos_all_cc"],
@@ -1330,6 +1336,7 @@
"//tensorflow/core/kernels:rpc_op",
"//tensorflow/core/kernels:scoped_allocator_ops",
"//tensorflow/core/kernels:sdca_ops",
+ "//tensorflow/core/kernels:searchsorted_op",
"//tensorflow/core/kernels:set_kernels",
"//tensorflow/core/kernels:sparse",
"//tensorflow/core/kernels:state",
@@ -2095,6 +2102,7 @@
deps = tf_additional_lib_deps() + [
"@com_google_absl//absl/strings",
"//third_party/eigen3",
+ "@com_google_absl//absl/base:core_headers",
"//tensorflow/core/platform/default/build_config:platformlib",
] + if_static([":lib_internal_impl"]),
)
@@ -2287,6 +2295,7 @@
deps = [
"//tensorflow/core/platform/default/build_config:jpeg",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
)
@@ -2319,6 +2328,7 @@
deps = [
"//tensorflow/core/platform/default/build_config:gif",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
)
@@ -2978,12 +2988,16 @@
] + tf_additional_device_tracer_deps(),
)
-cc_library(
- name = "session_ref",
- srcs = ["common_runtime/session_ref.cc"],
- hdrs = ["common_runtime/session_ref.h"],
- copts = tf_copts(),
- deps = [":core_cpu_base"],
+tf_proto_library_cc(
+ name = "replay_log_proto",
+ srcs = ["protobuf/replay_log.proto"],
+ cc_api_version = 2,
+ protodeps = [
+ ":master_proto",
+ ] + tf_additional_all_protos(),
+ visibility = [
+ "//tensorflow:internal",
+ ],
)
cc_library(
diff --git a/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt
new file mode 100644
index 0000000..5ce825a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt
@@ -0,0 +1,45 @@
+op {
+ graph_op_name: "LowerBound"
+ visibility: HIDDEN
+ in_arg {
+ name: "sorted_inputs"
+ description: <<END
+2-D Tensor where each row is ordered.
+END
+ }
+ in_arg {
+ name: "values"
+ description: <<END
+2-D Tensor with the same number of rows as `sorted_inputs`. Contains
+the values that will be searched for in `sorted_inputs`.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A `Tensor` with the same shape as `values`. It contains the first scalar index
+into the last dimension where values can be inserted without changing the
+ordered property.
+END
+ }
+ summary: "Applies lower_bound(sorted_inputs, values) along each row."
+ description: <<END
+Each set of rows with the same index in (sorted_inputs, values) is treated
+independently. The resulting row is the equivalent of calling
+`np.searchsorted(sorted_inputs, values, side='left')`.
+
+The result is not a global index to the entire
+`Tensor`, but rather just the index in the last dimension.
+
+A 2-D example:
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = LowerBound(sorted_sequence, values)
+
+ result == [[1, 2, 2],
+ [0, 1, 5]]
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000..4cb8955
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,19 @@
+op {
+ graph_op_name: "PrintV2"
+ in_arg {
+ name: "input"
+ description: <<END
+The string scalar to print.
+END
+ }
+ attr {
+ name: "output_stream"
+ description: <<END
+A string specifying the output stream or logging level to print to.
+END
+ }
+ summary: "Prints a string scalar."
+ description: <<END
+Prints a string scalar to the desired output_stream.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000..a82dae9
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,38 @@
+op {
+ graph_op_name: "StringFormat"
+ in_arg {
+ name: "inputs"
+ description: <<END
+The list of tensors to format into the placeholder string.
+END
+ }
+
+ out_arg {
+ name: "output"
+ description: <<END
+The resulting string scalar.
+END
+ }
+ attr {
+ name: "template"
+ description: <<END
+A string, the template to format tensor summaries into.
+END
+ }
+ attr {
+ name: "placeholder"
+ description: <<END
+A string. At each occurrence of the placeholder in the template, the next tensor summary will be inserted.
+END
+ }
+ attr {
+ name: "summarize"
+ description: <<END
+When formatting the tensor summaries, print the first and last `summarize` entries of each tensor dimension.
+END
+ }
+ summary: "Formats a string template using a list of tensors."
+ description: <<END
+Formats a string template using a list of tensors, pretty-printing tensor summaries.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt
new file mode 100644
index 0000000..0630f6e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt
@@ -0,0 +1,45 @@
+op {
+ graph_op_name: "UpperBound"
+ visibility: HIDDEN
+ in_arg {
+ name: "sorted_inputs"
+ description: <<END
+2-D Tensor where each row is ordered.
+END
+ }
+ in_arg {
+ name: "values"
+ description: <<END
+2-D Tensor with the same number of rows as `sorted_inputs`. Contains
+the values that will be searched for in `sorted_inputs`.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A `Tensor` with the same shape as `values`. It contains the last scalar index
+into the last dimension where values can be inserted without changing the
+ordered property.
+END
+ }
+ summary: "Applies upper_bound(sorted_inputs, values) along each row."
+ description: <<END
+Each set of rows with the same index in (sorted_inputs, values) is treated
+independently. The resulting row is the equivalent of calling
+`np.searchsorted(sorted_inputs, values, side='right')`.
+
+The result is not a global index to the entire
+`Tensor`, but rather just the index in the last dimension.
+
+A 2-D example:
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = UpperBound(sorted_sequence, values)
+
+ result == [[1, 2, 4],
+ [0, 2, 5]]
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000..e22d980
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "PrintV2"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000..8f0b1db
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StringFormat"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index 364071e..2d74bf2 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -304,7 +304,7 @@
};
// Returns 'bytes' rounded up to the next highest kMinAllocationSize.
- size_t RoundedBytes(size_t bytes);
+ static size_t RoundedBytes(size_t bytes);
// Try to add a new memory region that can satisfy an allocation of
// 'rounded_bytes' bytes. Returns true on success and false on
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 81d68e3..fb76d6a 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -106,6 +106,10 @@
// at completion.
virtual Status Sync() = 0;
+ // Override this to return true for devices that require a Sync() call before
+ // session completion.
+ virtual bool RequiresSyncOnCompletion() const { return false; }
+
// Optionally modify the device's GraphDef before execution.
//
// This method should be considered experimental and is supplied to enable
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 8486539..9871954 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -76,56 +76,47 @@
namespace nodestats {
inline int64 NowInNsec() { return Env::Default()->NowNanos(); }
-void SetScheduled(NodeExecStatsWrapper* stats, int64 micros) {
+void SetScheduled(NodeExecStatsInterface* stats, int64 micros) {
if (!stats) return;
stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
}
-void SetAllStart(NodeExecStatsWrapper* stats) {
+void SetAllStart(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordExecutorStarted();
}
-void SetOpStart(NodeExecStatsWrapper* stats) {
+void SetOpStart(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordComputeStarted();
}
-void SetOpEnd(NodeExecStatsWrapper* stats) {
+void SetOpEnd(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordComputeEnded();
}
-void SetAllEnd(NodeExecStatsWrapper* stats) {
+void SetAllEnd(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordExecutorEnded();
}
-void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) {
+void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
if (!stats) return;
stats->SetOutput(slot, v);
}
-void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) {
+void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
if (!stats) return;
stats->SetMemory(ctx);
}
-void SetReferencedTensors(NodeExecStatsWrapper* stats,
+void SetReferencedTensors(NodeExecStatsInterface* stats,
const TensorReferenceVector& tensors) {
if (!stats) return;
stats->SetReferencedTensors(tensors);
}
-// Sets the timeline_label field of *stats, using data from *node.
-// Returns true iff the node is a transfer node.
-bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
- if (!stats) {
- return false;
- }
- return stats->SetTimelineLabel(node);
-}
-
} // namespace nodestats
class ExecutorImpl;
@@ -1301,7 +1292,7 @@
// After item->kernel computation is done, processes its outputs.
Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
- EntryVector* outputs, NodeExecStatsWrapper* stats);
+ EntryVector* outputs, NodeExecStatsInterface* stats);
// After processing the outputs, propagates the outputs to their dsts.
// Contents of *outputs are left in an indeterminate state after
@@ -1312,7 +1303,7 @@
// "node" just finishes. Takes ownership of "stats". Returns true if
// execution has completed.
bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
- NodeExecStatsWrapper* stats,
+ NodeExecStatsInterface* stats,
TaggedNodeReadyQueue* inline_ready);
// Schedule all the expensive nodes in 'ready', and put all the inexpensive
@@ -1513,7 +1504,7 @@
struct ExecutorState::AsyncState {
AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
const NodeItem* _item, Entry* _first_input,
- NodeExecStatsWrapper* _stats)
+ NodeExecStatsInterface* _stats)
: saved_inputs(*p.inputs),
saved_input_device_contexts(*p.input_device_contexts),
saved_input_alloc_attrs(*p.input_alloc_attrs),
@@ -1538,7 +1529,7 @@
const NodeItem* item;
Entry* first_input;
OpKernelContext ctx;
- NodeExecStatsWrapper* stats;
+ NodeExecStatsInterface* stats;
private:
OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
@@ -1583,7 +1574,7 @@
params.stats_collector = stats_collector_;
Status s;
- NodeExecStatsWrapper* stats = nullptr;
+ NodeExecStatsInterface* stats = nullptr;
EntryVector outputs;
bool completed = false;
inline_ready.push_back(tagged_node);
@@ -1613,7 +1604,7 @@
if (stats_collector_ && !tagged_node.is_dead) {
// track allocations if and only if we are collecting statistics
params.track_allocations = true;
- stats = new NodeExecStatsWrapper(node->name());
+ stats = stats_collector_->CreateNodeExecStats(node);
nodestats::SetScheduled(stats, scheduled_nsec);
nodestats::SetAllStart(stats);
}
@@ -1671,7 +1662,7 @@
auto done = [this, state]() {
Device* device = impl_->params_.device;
- NodeExecStatsWrapper* stats = state->stats; // Shorthand
+ NodeExecStatsInterface* stats = state->stats; // Shorthand
Entry* first_input = state->first_input; // Shorthand
nodestats::SetOpEnd(stats);
@@ -1862,7 +1853,7 @@
Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
EntryVector* outputs,
- NodeExecStatsWrapper* stats) {
+ NodeExecStatsInterface* stats) {
const Node* node = item.node;
DCHECK_EQ(0, outputs->size());
outputs->resize(item.num_outputs);
@@ -2080,16 +2071,15 @@
bool ExecutorState::NodeDone(const Status& s, const Node* node,
const TaggedNodeSeq& ready,
- NodeExecStatsWrapper* stats,
+ NodeExecStatsInterface* stats,
TaggedNodeReadyQueue* inline_ready) {
nodestats::SetAllEnd(stats);
- if (stats_collector_ != nullptr &&
- !nodestats::SetTimelineLabel(node, stats)) {
- // Only record non-transfer nodes.
- // Transfers 'stats' ownership to 'stats_collector_'.
- stats_collector_->Save(impl_->params_.device->name(), stats);
- } else if (stats) {
- delete stats;
+ if (stats) {
+ if (stats_collector_) {
+ stats->Done(impl_->params_.device->name());
+ } else {
+ delete stats;
+ }
}
bool abort_run = false;
@@ -2311,13 +2301,15 @@
auto done_cb = std::move(done_cb_);
auto runner = std::move(runner_);
mu_.unlock();
- if (sync_on_finish_ && status.ok()) {
+ Device* device = impl_->params_.device;
+ if ((sync_on_finish_ && status.ok()) || device->RequiresSyncOnCompletion()) {
// Block until the device has finished all queued operations. For
// devices like GPUs that continue to execute Ops after their Compute
// methods have completed, this ensures that control is not returned to
// the user until the step (and its side-effects) has actually completed.
- status = impl_->params_.device->Sync();
+ status.Update(device->Sync());
}
+
delete this;
CHECK(done_cb != nullptr);
runner([=]() { done_cb(status); });
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
index 44ffce7..42021e5 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -22,6 +22,39 @@
namespace tensorflow {
+bool GPUBFCAllocator::GetAllowGrowthValue(const GPUOptions& gpu_options) {
+ const char* force_allow_growth_string =
+ std::getenv("TF_FORCE_GPU_ALLOW_GROWTH");
+ if (force_allow_growth_string == nullptr) {
+ return gpu_options.allow_growth();
+ }
+
+ if (strcmp("false", force_allow_growth_string) == 0) {
+ if (gpu_options.allow_growth()) {
+ LOG(WARNING)
+ << "Overriding allow_growth setting because the"
+ << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original"
+ << " config value was " << gpu_options.allow_growth() << ".";
+ }
+ return false;
+ } else if (strcmp("true", force_allow_growth_string) == 0) {
+ if (!gpu_options.allow_growth()) {
+ LOG(WARNING)
+ << "Overriding allow_growth setting because the"
+ << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original"
+ << " config value was " << gpu_options.allow_growth() << ".";
+ }
+ return true;
+ }
+
+ LOG(ERROR)
+ << "The TF_FORCE_GPU_ALLOW_GROWTH environment variable is set but could"
+ << " not be parsed: \"" << force_allow_growth_string << "\". Valid"
+ << " values are \"true\" or \"false\". Using original config value"
+ << " of " << gpu_options.allow_growth() << ".";
+ return gpu_options.allow_growth();
+}
+
GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
size_t total_memory, const string& name)
: GPUBFCAllocator(sub_allocator, total_memory, GPUOptions(), name) {}
@@ -30,7 +63,7 @@
size_t total_memory,
const GPUOptions& gpu_options,
const string& name)
- : BFCAllocator(sub_allocator, total_memory, gpu_options.allow_growth(),
- name) {}
+ : BFCAllocator(sub_allocator, total_memory,
+ GPUBFCAllocator::GetAllowGrowthValue(gpu_options), name) {}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index 6b6de80..d4c9cee 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -34,11 +34,11 @@
// Suballocator for GPU memory.
class GPUMemAllocator : public SubAllocator {
public:
- // 'cuda_gpu_id' refers to the ID of the GPU device within
+ // 'platform_gpu_id' refers to the ID of the GPU device within
// the process and must reference a valid ID in the process.
// Note: stream_exec cannot be null.
- explicit GPUMemAllocator(se::StreamExecutor* stream_exec, CudaGpuId gpu_id,
- bool use_unified_memory,
+ explicit GPUMemAllocator(se::StreamExecutor* stream_exec,
+ PlatformGpuId gpu_id, bool use_unified_memory,
const std::vector<Visitor>& alloc_visitors,
const std::vector<Visitor>& free_visitors)
: SubAllocator(alloc_visitors, free_visitors),
@@ -76,7 +76,7 @@
private:
se::StreamExecutor* stream_exec_; // not owned, non-null
- const CudaGpuId gpu_id_;
+ const PlatformGpuId gpu_id_;
const bool use_unified_memory_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(GPUMemAllocator);
@@ -93,6 +93,9 @@
~GPUBFCAllocator() override {}
TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
+
+ private:
+ static bool GetAllowGrowthValue(const GPUOptions& gpu_options);
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
index 7112c3a..60e82ed 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
@@ -47,10 +47,10 @@
}
TEST(GPUBFCAllocatorTest, NoDups) {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
CheckStats(&a, 0, 0, 0, 0);
@@ -80,10 +80,10 @@
}
TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
// Allocate 256 raw pointers of sizes between 100 bytes and about
// a meg
@@ -142,10 +142,10 @@
}
TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
CheckStats(&a, 0, 0, 0, 0);
@@ -181,29 +181,29 @@
}
TEST(GPUBFCAllocatorTest, AllocateZeroBufSize) {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
float* ptr = a.Allocate<float>(0);
EXPECT_EQ(nullptr, ptr);
}
TEST(GPUBFCAllocatorTest, TracksSizes) {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
EXPECT_EQ(true, a.TracksAllocationSizes());
}
TEST(GPUBFCAllocatorTest, AllocatedVsRequested) {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
float* t1 = a.Allocate<float>(1);
EXPECT_EQ(4, a.RequestedSize(t1));
@@ -212,10 +212,10 @@
}
TEST(GPUBFCAllocatorTest, TestCustomMemoryLimit) {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
// Configure a 1MiB byte limit
GPUBFCAllocator a(sub_allocator, 1 << 20, "GPU_0_bfc");
@@ -232,10 +232,10 @@
options.set_allow_growth(true);
// Max of 2GiB, but starts out small.
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1LL << 31, "GPU_0_bfc");
// Allocate 10 raw pointers of sizes between 100 bytes and about
@@ -297,14 +297,14 @@
}
TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1UL << 60, "GPU_0_bfc");
sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator b(sub_allocator, 1UL << 60, "GPU_0_bfc");
void* amem = a.AllocateRaw(1, 1);
void* bmem = b.AllocateRaw(1, 1 << 30);
@@ -313,10 +313,10 @@
}
static void BM_Allocation(int iters) {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc");
// Exercise a few different allocation sizes
std::vector<size_t> sizes = {256, 4096, 16384, 524288,
@@ -333,10 +333,10 @@
BENCHMARK(BM_Allocation);
static void BM_AllocationThreaded(int iters, int num_threads) {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc");
thread::ThreadPool pool(Env::Default(), "test", num_threads);
std::atomic_int_fast32_t count(iters);
@@ -373,10 +373,10 @@
// A more complex benchmark that defers deallocation of an object for
// "delay" allocations.
static void BM_AllocationDelayed(int iters, int delay) {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
// Exercise a few different allocation sizes
std::vector<int> sizes = {256, 4096, 16384, 4096, 512, 1024, 1024};
@@ -410,15 +410,17 @@
class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
protected:
+ void SetUp() override { CHECK_EQ(unsetenv("TF_FORCE_GPU_ALLOW_GROWTH"), 0); }
+
// The following test methods are called from tests. The reason for this is
// that this class is a friend class to BFCAllocator, but tests are not, so
// only methods inside this class can access private members of BFCAllocator.
void TestBinDebugInfo() {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
std::vector<void*> initial_ptrs;
@@ -497,10 +499,10 @@
}
void TestLog2FloorNonZeroSlow() {
- CudaGpuId cuda_gpu_id(0);
+ PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUBFCAllocator a(sub_allocator, 1 /* total_memory */, "GPU_0_bfc");
EXPECT_EQ(-1, a.Log2FloorNonZeroSlow(0));
EXPECT_EQ(0, a.Log2FloorNonZeroSlow(1));
@@ -510,6 +512,56 @@
EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1024));
EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1025));
}
+
+ void TestForceAllowGrowth() {
+ PlatformGpuId platform_gpu_id(0);
+ GPUOptions options;
+ // Unset flag value uses provided option.
+ unsetenv("TF_FORCE_GPU_ALLOW_GROWTH");
+ options.set_allow_growth(true);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator unset_flag_allocator(sub_allocator, 1LL << 31, options,
+ "GPU_0_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ unset_flag_allocator.curr_region_allocation_bytes_);
+
+ // Unparseable flag value uses provided option.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "unparseable", 1);
+ options.set_allow_growth(true);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator unparsable_flag_allocator(sub_allocator, 1LL << 31, options,
+ "GPU_1_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ unparsable_flag_allocator.curr_region_allocation_bytes_);
+
+ // Max of 2GiB total memory. Env variable set forces allow_growth, which
+ // does an initial allocation of 1MiB.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "true", 1);
+ options.set_allow_growth(false);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator force_allow_growth_allocator(sub_allocator, 1LL << 31,
+ options, "GPU_2_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ force_allow_growth_allocator.curr_region_allocation_bytes_);
+
+ // If env variable forces allow_growth disabled, all available memory is
+ // allocated.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "false", 1);
+ options.set_allow_growth(true);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator force_no_allow_growth_allocator(sub_allocator, 1LL << 31,
+ options, "GPU_3_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(1LL << 31),
+ force_no_allow_growth_allocator.curr_region_allocation_bytes_);
+ }
};
TEST_F(GPUBFCAllocatorPrivateMethodsTest, BinDebugInfo) { TestBinDebugInfo(); }
@@ -518,6 +570,10 @@
TestLog2FloorNonZeroSlow();
}
+TEST_F(GPUBFCAllocatorPrivateMethodsTest, ForceAllowGrowth) {
+ TestForceAllowGrowth();
+}
+
} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
index 8e14f1e..d85ca88 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
@@ -28,9 +28,10 @@
namespace tensorflow {
GPUcudaMallocAllocator::GPUcudaMallocAllocator(Allocator* allocator,
- CudaGpuId cuda_gpu_id)
+ PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
- stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ stream_exec_ =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
}
GPUcudaMallocAllocator::~GPUcudaMallocAllocator() { delete base_allocator_; }
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
index 3d1d0ef..8df3724 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
@@ -31,7 +31,8 @@
// allocated memory.
class GPUcudaMallocAllocator : public Allocator {
public:
- explicit GPUcudaMallocAllocator(Allocator* allocator, CudaGpuId cuda_gpu_id);
+ explicit GPUcudaMallocAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id);
~GPUcudaMallocAllocator() override;
string Name() override { return "gpu_debug"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
index 6bad66d..989ddbe 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
@@ -74,9 +74,10 @@
// GPUDebugAllocator
// -----------------------------------------------------------------------------
GPUDebugAllocator::GPUDebugAllocator(Allocator* allocator,
- CudaGpuId cuda_gpu_id)
+ PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
- stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ stream_exec_ =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
}
GPUDebugAllocator::~GPUDebugAllocator() { delete base_allocator_; }
@@ -151,9 +152,10 @@
// GPUNanResetAllocator
// -----------------------------------------------------------------------------
GPUNanResetAllocator::GPUNanResetAllocator(Allocator* allocator,
- CudaGpuId cuda_gpu_id)
+ PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
- stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ stream_exec_ =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
}
GPUNanResetAllocator::~GPUNanResetAllocator() { delete base_allocator_; }
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
index 0f27ff4..17757a1 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -33,7 +33,8 @@
// allocated memory.
class GPUDebugAllocator : public Allocator {
public:
- explicit GPUDebugAllocator(Allocator* allocator, CudaGpuId cuda_gpu_id);
+ explicit GPUDebugAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id);
~GPUDebugAllocator() override;
string Name() override { return "gpu_debug"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
@@ -62,7 +63,8 @@
// user forgets to initialize the memory.
class GPUNanResetAllocator : public Allocator {
public:
- explicit GPUNanResetAllocator(Allocator* allocator, CudaGpuId cuda_gpu_id);
+ explicit GPUNanResetAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id);
~GPUNanResetAllocator() override;
string Name() override { return "gpu_nan_reset"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
index 98283cd..aca08a7 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
@@ -34,13 +34,14 @@
namespace {
TEST(GPUDebugAllocatorTest, OverwriteDetection_None) {
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
- cuda_gpu_id);
- auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ platform_gpu_id);
+ auto stream_exec =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
for (int s : {8}) {
std::vector<int64> cpu_array(s);
@@ -61,14 +62,14 @@
for (int s : {8, 211}) {
EXPECT_DEATH(
{
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(),
- cuda_gpu_id, false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
- cuda_gpu_id);
+ platform_gpu_id);
auto stream_exec =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<int64> cpu_array(s);
memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
@@ -97,14 +98,14 @@
for (int s : {8, 22}) {
EXPECT_DEATH(
{
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(),
- cuda_gpu_id, false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
- cuda_gpu_id);
+ platform_gpu_id);
auto stream_exec =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<int64> cpu_array(s);
memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
@@ -130,13 +131,14 @@
}
TEST(GPUDebugAllocatorTest, ResetToNan) {
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUNanResetAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
- cuda_gpu_id);
- auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ platform_gpu_id);
+ auto stream_exec =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<float> cpu_array(1024);
std::vector<float> cpu_array_result(1024);
@@ -173,16 +175,17 @@
}
TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
// NaN reset must be the outer-most allocator.
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUNanResetAllocator a(
new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
- cuda_gpu_id),
- cuda_gpu_id);
- auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ platform_gpu_id),
+ platform_gpu_id);
+ auto stream_exec =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<float> cpu_array(1024);
std::vector<float> cpu_array_result(1024);
@@ -219,24 +222,24 @@
}
TEST(GPUDebugAllocatorTest, TracksSizes) {
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
- cuda_gpu_id);
+ platform_gpu_id);
EXPECT_EQ(true, a.TracksAllocationSizes());
}
TEST(GPUDebugAllocatorTest, AllocatedVsRequested) {
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
- false /*use_unified_memory*/, {}, {});
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUNanResetAllocator a(
new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
- cuda_gpu_id),
- cuda_gpu_id);
+ platform_gpu_id),
+ platform_gpu_id);
float* t1 = a.Allocate<float>(1);
EXPECT_EQ(4, a.RequestedSize(t1));
EXPECT_EQ(256, a.AllocatedSize(t1));
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 50e61b7..cf3faf6 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -104,9 +104,9 @@
reinterpret_cast<unsigned int*>(scratch + Eigen::kCudaScratchSize);
stream_ = cuda_stream;
allocator_ = alloc;
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
- device_prop_ = &Eigen::m_deviceProperties[cuda_gpu_id.value()];
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+ device_prop_ = &Eigen::m_deviceProperties[platform_gpu_id.value()];
}
const cudaStream_t& stream() const override { return *stream_; }
@@ -342,9 +342,10 @@
gpu_device_info_->stream = streams_[0]->compute;
gpu_device_info_->default_context = device_contexts_[0];
gpu_device_info_->event_mgr = em_.get();
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id));
- gpu_device_info_->gpu_id = cuda_gpu_id.value();
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id));
+ gpu_device_info_->gpu_id = platform_gpu_id.value();
set_tensorflow_gpu_device_info(gpu_device_info_);
// Whether and how the GPU device uses its own threadpool.
@@ -700,9 +701,9 @@
Eigen::GpuDevice device_;
};
-// Parse 'visible_device_list' into a list of CUDA GPU ids.
+// Parse 'visible_device_list' into a list of platform GPU ids.
Status ParseVisibleDeviceList(const string& visible_device_list,
- std::vector<CudaGpuId>* visible_gpu_order) {
+ std::vector<PlatformGpuId>* visible_gpu_order) {
visible_gpu_order->clear();
se::Platform* gpu_manager = GPUMachineManager();
@@ -717,26 +718,28 @@
} else {
const std::vector<string> order_str =
str_util::Split(visible_device_list, ',');
- for (const string& cuda_gpu_id_str : order_str) {
- int32 cuda_gpu_id;
- if (!strings::safe_strto32(cuda_gpu_id_str, &cuda_gpu_id)) {
+ for (const string& platform_gpu_id_str : order_str) {
+ int32 platform_gpu_id;
+ if (!strings::safe_strto32(platform_gpu_id_str, &platform_gpu_id)) {
return errors::InvalidArgument(
"Could not parse entry in 'visible_device_list': '",
- cuda_gpu_id_str, "'. visible_device_list = ", visible_device_list);
+ platform_gpu_id_str, "'. visible_device_list = ",
+ visible_device_list);
}
- if (cuda_gpu_id < 0 || cuda_gpu_id >= gpu_manager->VisibleDeviceCount()) {
+ if (platform_gpu_id < 0 ||
+ platform_gpu_id >= gpu_manager->VisibleDeviceCount()) {
return errors::InvalidArgument(
- "'visible_device_list' listed an invalid GPU id '", cuda_gpu_id,
+ "'visible_device_list' listed an invalid GPU id '", platform_gpu_id,
"' but visible device count is ",
gpu_manager->VisibleDeviceCount());
}
- visible_gpu_order->push_back(CudaGpuId(cuda_gpu_id));
+ visible_gpu_order->push_back(PlatformGpuId(platform_gpu_id));
}
}
// Validate no repeats.
- std::set<CudaGpuId> visible_device_set(visible_gpu_order->begin(),
- visible_gpu_order->end());
+ std::set<PlatformGpuId> visible_device_set(visible_gpu_order->begin(),
+ visible_gpu_order->end());
if (visible_device_set.size() != visible_gpu_order->size()) {
return errors::InvalidArgument(
"visible_device_list contained a duplicate entry: ",
@@ -747,8 +750,8 @@
Status VerifyVirtualDeviceSettings(
const size_t num_gpus_to_use, const GPUOptions& gpu_options,
- const std::vector<CudaGpuId>& visible_gpu_order,
- const std::vector<CudaGpuId>& valid_cuda_gpu_ids) {
+ const std::vector<PlatformGpuId>& visible_gpu_order,
+ const std::vector<PlatformGpuId>& valid_platform_gpu_ids) {
const auto& virtual_devices = gpu_options.experimental().virtual_devices();
CHECK(!virtual_devices.empty());
if (gpu_options.per_process_gpu_memory_fraction() > 0) {
@@ -770,11 +773,11 @@
" #GPUs in visible_device_list: ", visible_gpu_order.size(),
" virtual_devices.size(): ", virtual_devices.size());
}
- if (valid_cuda_gpu_ids.size() != virtual_devices.size()) {
+ if (valid_platform_gpu_ids.size() != virtual_devices.size()) {
return errors::Unknown(
"The number of valid GPUs doesn't match the number of elements in "
"the virtual_devices list.",
- " #valid GPUs: ", valid_cuda_gpu_ids.size(),
+ " #valid GPUs: ", valid_platform_gpu_ids.size(),
" virtual_devices.size(): ", virtual_devices.size());
}
return Status::OK();
@@ -816,18 +819,18 @@
}
// Get the memory limit for the virtual device being created on GPU with
-// 'cuda_gpu_id', when that virtual device is the only virtual device being
+// 'platform_gpu_id', when that virtual device is the only virtual device being
// created on that GPU.
Status SingleVirtualDeviceMemoryLimit(const GPUOptions& gpu_options,
- CudaGpuId cuda_gpu_id,
+ PlatformGpuId platform_gpu_id,
int64* memory_limit) {
int64 total_memory = 0;
int64 available_memory = 0;
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
if (!se->DeviceMemoryUsage(&available_memory, &total_memory)) {
return errors::Unknown("Failed to query available memory for GPU ",
- cuda_gpu_id.value());
+ platform_gpu_id.value());
}
int64 allocated_memory = 0;
@@ -928,8 +931,8 @@
num_gpus_to_use = iter->second;
}
const auto& gpu_options = options.config.gpu_options();
- std::vector<CudaGpuId> visible_gpu_order;
- std::vector<CudaGpuId> valid_cuda_gpu_ids;
+ std::vector<PlatformGpuId> visible_gpu_order;
+ std::vector<PlatformGpuId> valid_platform_gpu_ids;
// If we aren't going to use any GPUs, don't initialize them.
// We don't want to call ParseVisibleDeviceList if num_gpus_to_use is 0,
// because it treats an empty gpu_options.visible_device_list as 'all GPUs are
@@ -938,12 +941,12 @@
TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
&visible_gpu_order));
TF_RETURN_IF_ERROR(
- GetValidDeviceIds(visible_gpu_order, &valid_cuda_gpu_ids));
+ GetValidDeviceIds(visible_gpu_order, &valid_platform_gpu_ids));
}
- if (num_gpus_to_use > valid_cuda_gpu_ids.size()) {
- num_gpus_to_use = valid_cuda_gpu_ids.size();
+ if (num_gpus_to_use > valid_platform_gpu_ids.size()) {
+ num_gpus_to_use = valid_platform_gpu_ids.size();
}
- if (!valid_cuda_gpu_ids.empty()) {
+ if (!valid_platform_gpu_ids.empty()) {
// Save the original device.
int original_device = 0;
cudaError_t err = cudaGetDevice(&original_device);
@@ -953,17 +956,18 @@
}
// Force to implicitly initialize CUDA runtime on each valid GPU before
// CreateGPUDevice().
- for (CudaGpuId cuda_gpu_id : valid_cuda_gpu_ids) {
- err = cudaSetDevice(cuda_gpu_id.value());
+ for (PlatformGpuId platform_gpu_id : valid_platform_gpu_ids) {
+ err = cudaSetDevice(platform_gpu_id.value());
if (err != cudaSuccess) {
- return errors::Internal("cudaSetDevice() on GPU:", cuda_gpu_id.value(),
- " failed. Status: ", cudaGetErrorString(err));
+ return errors::Internal("cudaSetDevice() on GPU:",
+ platform_gpu_id.value(), " failed. Status: ",
+ cudaGetErrorString(err));
}
err = cudaFree(nullptr);
if (err != cudaSuccess) {
- return errors::Internal(
- "CUDA runtime implicit initialization on GPU:", cuda_gpu_id.value(),
- " failed. Status: ", cudaGetErrorString(err));
+ return errors::Internal("CUDA runtime implicit initialization on GPU:",
+ platform_gpu_id.value(), " failed. Status: ",
+ cudaGetErrorString(err));
}
}
// Reset to the original device.
@@ -989,10 +993,10 @@
LOG(INFO) << line_buf;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
line_buf = strings::StrCat(visible_gpu_order[i].value(), ": ");
- CudaGpuId cuda_id_i = visible_gpu_order[i];
+ PlatformGpuId gpu_id_i = visible_gpu_order[i];
for (int j = 0; j < visible_gpu_order.size(); ++j) {
- CudaGpuId cuda_id_j = visible_gpu_order[j];
- if (im.directed_links.find({cuda_id_i, cuda_id_j}) !=
+ PlatformGpuId gpu_id_j = visible_gpu_order[j];
+ if (im.directed_links.find({gpu_id_i, gpu_id_j}) !=
im.directed_links.end()) {
line_buf.append("Y ");
} else {
@@ -1005,22 +1009,23 @@
const auto& virtual_devices = gpu_options.experimental().virtual_devices();
if (!virtual_devices.empty()) {
- TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings(
- num_gpus_to_use, gpu_options, visible_gpu_order, valid_cuda_gpu_ids));
+ TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings(num_gpus_to_use, gpu_options,
+ visible_gpu_order,
+ valid_platform_gpu_ids));
// We've verified that num_gpus_to_use >= virtual_devices.size().
num_gpus_to_use = virtual_devices.size();
CHECK(gpu_options.visible_device_list().empty() ||
- valid_cuda_gpu_ids == visible_gpu_order);
+ valid_platform_gpu_ids == visible_gpu_order);
}
int next_tf_gpu_id = 0;
std::vector<int64> memory_limit_bytes;
for (int i = 0; i < num_gpus_to_use; ++i) {
- const CudaGpuId cuda_gpu_id = valid_cuda_gpu_ids[i];
+ const PlatformGpuId platform_gpu_id = valid_platform_gpu_ids[i];
if (virtual_devices.empty() ||
virtual_devices.Get(i).memory_limit_mb_size() == 0) {
int64 single_virtual_device_memory_limit = 0;
TF_RETURN_IF_ERROR(SingleVirtualDeviceMemoryLimit(
- gpu_options, cuda_gpu_id, &single_virtual_device_memory_limit));
+ gpu_options, platform_gpu_id, &single_virtual_device_memory_limit));
memory_limit_bytes.push_back(single_virtual_device_memory_limit);
} else {
const auto& memory_limit_mb = virtual_devices.Get(i).memory_limit_mb();
@@ -1033,7 +1038,7 @@
TfGpuId tf_gpu_id(next_tf_gpu_id);
++next_tf_gpu_id;
TF_RETURN_IF_ERROR(
- GpuIdManager::InsertTfCudaGpuIdPair(tf_gpu_id, cuda_gpu_id));
+ GpuIdManager::InsertTfPlatformGpuIdPair(tf_gpu_id, platform_gpu_id));
}
}
const int num_tf_gpus = next_tf_gpu_id;
@@ -1058,7 +1063,7 @@
return Status::OK();
}
-static string GetShortDeviceDescription(CudaGpuId cuda_gpu_id,
+static string GetShortDeviceDescription(PlatformGpuId platform_gpu_id,
const se::DeviceDescription& desc) {
int cc_major;
int cc_minor;
@@ -1067,9 +1072,8 @@
cc_minor = 0;
}
// LINT.IfChange
- return strings::StrCat("device: ", cuda_gpu_id.value(),
- ", name: ", desc.name(),
- ", pci bus id: ", desc.pci_bus_id(),
+ return strings::StrCat("device: ", platform_gpu_id.value(), ", name: ",
+ desc.name(), ", pci bus id: ", desc.pci_bus_id(),
", compute capability: ", cc_major, ".", cc_minor);
// LINT.ThenChange(//tensorflow/python/platform/test.py)
}
@@ -1084,12 +1088,13 @@
const string device_name =
strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value());
GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
int numa_node = dev_locality.numa_node();
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
const se::DeviceDescription& desc = se->GetDeviceDescription();
GPUProcessState* process_state = GPUProcessState::singleton();
Allocator* gpu_allocator = process_state->GetGPUAllocator(
@@ -1110,11 +1115,11 @@
// TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit.
BaseGPUDevice* gpu_device = CreateGPUDevice(
options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality,
- tf_gpu_id, GetShortDeviceDescription(cuda_gpu_id, desc), gpu_allocator,
- ProcessState::singleton()->GetCPUAllocator(numa_node));
+ tf_gpu_id, GetShortDeviceDescription(platform_gpu_id, desc),
+ gpu_allocator, ProcessState::singleton()->GetCPUAllocator(numa_node));
LOG(INFO) << "Created TensorFlow device (" << device_name << " with "
<< (stats.bytes_limit >> 20) << " MB memory) -> physical GPU ("
- << GetShortDeviceDescription(cuda_gpu_id, desc) << ")";
+ << GetShortDeviceDescription(platform_gpu_id, desc) << ")";
TF_RETURN_IF_ERROR(gpu_device->Init(options));
devices->push_back(gpu_device);
@@ -1122,18 +1127,21 @@
}
namespace {
-std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>>
+std::unique_ptr<std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>>
GetPeerAccessMap(se::Platform* platform,
- const std::vector<CudaGpuId>& visible_gpu_order) {
- std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>> map(
- new std::map<std::pair<CudaGpuId, CudaGpuId>, bool>);
- for (CudaGpuId cuda_gpu_i : visible_gpu_order) {
- for (CudaGpuId cuda_gpu_j : visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
+ std::unique_ptr<std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>> map(
+ new std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>);
+ for (PlatformGpuId platform_gpu_i : visible_gpu_order) {
+ for (PlatformGpuId platform_gpu_j : visible_gpu_order) {
se::StreamExecutor* from =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_i)
+ .ValueOrDie();
se::StreamExecutor* to =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie();
- (*map)[{cuda_gpu_i, cuda_gpu_j}] = from->CanEnablePeerAccessTo(to);
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_j)
+ .ValueOrDie();
+ (*map)[{platform_gpu_i, platform_gpu_j}] =
+ from->CanEnablePeerAccessTo(to);
}
}
@@ -1143,19 +1151,19 @@
} // namespace
Status BaseGPUDeviceFactory::GetInterconnectMaps(
- const std::vector<CudaGpuId>& visible_gpu_order, se::Platform* gpu_manager,
- std::vector<InterconnectMap>* maps) {
+ const std::vector<PlatformGpuId>& visible_gpu_order,
+ se::Platform* gpu_manager, std::vector<InterconnectMap>* maps) {
// The default interconnect map is obtained from the StreamExecutor.
auto access_map = GetPeerAccessMap(gpu_manager, visible_gpu_order);
maps->resize(1);
InterconnectMap& imap = maps->at(0);
imap.name = "StreamExecutor";
imap.strength = InterconnectMap::kStreamExecutorStrength;
- for (CudaGpuId cuda_id_i : visible_gpu_order) {
- for (CudaGpuId cuda_id_j : visible_gpu_order) {
- if (cuda_id_i == cuda_id_j) continue;
- if ((*access_map)[{cuda_id_i, cuda_id_j}]) {
- imap.directed_links.insert({cuda_id_i, cuda_id_j});
+ for (PlatformGpuId gpu_id_i : visible_gpu_order) {
+ for (PlatformGpuId gpu_id_j : visible_gpu_order) {
+ if (gpu_id_i == gpu_id_j) continue;
+ if ((*access_map)[{gpu_id_i, gpu_id_j}]) {
+ imap.directed_links.insert({gpu_id_i, gpu_id_j});
}
}
}
@@ -1170,13 +1178,14 @@
all_tf_gpu_ids.push_back(TfGpuId(i));
}
for (TfGpuId tf_gpu_id : all_tf_gpu_ids) {
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
// Get GPU bus_id from its reported NUMA affinity. Because GPUs are
// virtualized in some environments, we can't just use the GPU id.
// NUMA locales are indexed from 0, buses are indexed from 1.
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
const se::DeviceDescription& desc = se->GetDeviceDescription();
int numa_node = desc.numa_node();
if (numa_node < 0) {
@@ -1186,7 +1195,8 @@
// may run into trouble later with data transfer operations. The
// trouble may manifest as slower than expected performance, or
// outright failures.
- LOG(INFO) << "Could not identify NUMA node of CUDA gpu id " << cuda_gpu_id
+ LOG(INFO) << "Could not identify NUMA node of platform GPU id "
+ << platform_gpu_id
<< ", defaulting to 0. Your kernel may not have been built "
<< "with NUMA support.";
numa_node = 0;
@@ -1199,10 +1209,10 @@
LocalLinks* links = dev_locality.mutable_links();
for (const InterconnectMap& imap : interconnects) {
for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) {
- CudaGpuId cuda_gpu_dst;
+ PlatformGpuId platform_gpu_dst;
TF_RETURN_IF_ERROR(
- GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst));
- if (imap.directed_links.find({cuda_gpu_id, cuda_gpu_dst}) !=
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_dst, &platform_gpu_dst));
+ if (imap.directed_links.find({platform_gpu_id, platform_gpu_dst}) !=
imap.directed_links.end()) {
InterconnectLink* ilink = links->add_link();
ilink->set_device_id(tf_gpu_dst.value());
@@ -1216,10 +1226,10 @@
// add high strength links to the others.
for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) {
if (tf_gpu_id == tf_gpu_dst) continue;
- CudaGpuId cuda_gpu_dst;
+ PlatformGpuId platform_gpu_dst;
TF_RETURN_IF_ERROR(
- GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst));
- if (cuda_gpu_id == cuda_gpu_dst) {
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_dst, &platform_gpu_dst));
+ if (platform_gpu_id == platform_gpu_dst) {
InterconnectLink* ilink = links->add_link();
ilink->set_device_id(tf_gpu_dst.value());
ilink->set_type("SAME_DEVICE");
@@ -1228,9 +1238,9 @@
}
(*localities)[tf_gpu_id] = dev_locality;
- VLOG(1) << "GPUDevice CudaGpuId " << cuda_gpu_id << " TfGpuId " << tf_gpu_id
- << " on bus " << dev_locality.bus_id() << " numa: " << numa_node
- << " pci: " << desc.pci_bus_id()
+ VLOG(1) << "GPUDevice PlatformGpuId " << platform_gpu_id << " TfGpuId "
+ << tf_gpu_id << " on bus " << dev_locality.bus_id()
+ << " numa: " << numa_node << " pci: " << desc.pci_bus_id()
<< " DeviceLocality: " << dev_locality.DebugString();
}
return Status::OK();
@@ -1238,14 +1248,14 @@
static int GetDefaultMinGPUMultiprocessorCount(
se::Platform* gpu_manager,
- const std::vector<CudaGpuId>& visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
static const int kDefaultMinGPUMultiprocessorCount = 8;
// Find the highest multi-processor count across all visible GPUs.
int max_count = -1;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
auto exec_status =
- GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, visible_gpu_order[i]);
+ GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_order[i]);
if (!exec_status.ok()) {
continue;
}
@@ -1264,7 +1274,7 @@
static int GetMinGPUMultiprocessorCount(
se::Platform* gpu_manager,
- const std::vector<CudaGpuId>& visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
const char* tf_min_gpu_core_count = getenv("TF_MIN_GPU_MULTIPROCESSOR_COUNT");
if (tf_min_gpu_core_count == nullptr ||
@@ -1342,18 +1352,20 @@
}
Status EnablePeerAccess(se::Platform* platform,
- const std::vector<CudaGpuId>& visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
int possible_peer_count = 0;
int enabled_peer_count = 0;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
- const CudaGpuId cuda_gpu_i = visible_gpu_order[i];
+ const PlatformGpuId platform_gpu_i = visible_gpu_order[i];
for (int j = 0; j < visible_gpu_order.size(); ++j) {
- const CudaGpuId cuda_gpu_j = visible_gpu_order[j];
+ const PlatformGpuId platform_gpu_j = visible_gpu_order[j];
// We have already validated that ExecutorForDevice() calls return OK.
se::StreamExecutor* from =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_i)
+ .ValueOrDie();
se::StreamExecutor* to =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_j)
+ .ValueOrDie();
if (from->CanEnablePeerAccessTo(to)) {
++possible_peer_count;
@@ -1361,7 +1373,8 @@
if (!status.ok()) {
LOG(WARNING)
<< "Unable to enable peer access between device ordinals "
- << cuda_gpu_i << " and " << cuda_gpu_j << ", status: " << status;
+ << platform_gpu_i << " and " << platform_gpu_j
+ << ", status: " << status;
} else {
++enabled_peer_count;
}
@@ -1384,22 +1397,23 @@
} // namespace
Status BaseGPUDeviceFactory::GetValidDeviceIds(
- const std::vector<CudaGpuId>& visible_gpu_order,
- std::vector<CudaGpuId>* ids) {
+ const std::vector<PlatformGpuId>& visible_gpu_order,
+ std::vector<PlatformGpuId>* ids) {
se::Platform* gpu_manager = GPUMachineManager();
bool new_gpu_found = false;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
- const CudaGpuId cuda_gpu_id = visible_gpu_order[i];
+ const PlatformGpuId visible_gpu_id = visible_gpu_order[i];
- // Only perform this once per visible cuda gpu id.
- if (visible_gpu_initialized_[cuda_gpu_id.value()]) {
+ // Only perform this once per visible platform gpu id.
+ if (visible_gpu_initialized_[visible_gpu_id.value()]) {
continue;
}
- visible_gpu_initialized_[cuda_gpu_id.value()] = true;
+ visible_gpu_initialized_[visible_gpu_id.value()] = true;
new_gpu_found = true;
- auto executor = GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, cuda_gpu_id);
+ auto executor =
+ GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_id);
if (!executor.ok()) {
return executor.status();
}
@@ -1447,9 +1461,9 @@
// Filter out devices that don't have the right capability or power.
for (int i = 0; i < visible_gpu_order.size(); ++i) {
- const CudaGpuId visible_gpu_id = visible_gpu_order[i];
+ const PlatformGpuId visible_gpu_id = visible_gpu_order[i];
auto exec_status =
- GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, visible_gpu_id);
+ GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_id);
if (!exec_status.ok()) {
LOG(INFO) << "Ignoring visible gpu device " << visible_gpu_id
<< " whose executor is in invalid state: "
@@ -1498,7 +1512,7 @@
if (!ids->empty()) {
std::vector<int> raw_ids(ids->size());
std::transform(ids->begin(), ids->end(), raw_ids.begin(),
- [](CudaGpuId id) -> int { return id.value(); });
+ [](PlatformGpuId id) -> int { return id.value(); });
LOG(INFO) << "Adding visible gpu devices: "
<< str_util::Join(raw_ids, ", ");
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index b3eea55..b25fe86 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -90,12 +90,12 @@
DeviceContext* dc,
Allocator* allocator) override;
- // Returns the CUDA GPU id of this device within the native driver system;
+ // Returns the platform GPU id of this device within the native driver system;
// e.g., for CUDA this is the ordinal of the GPU within the system.
int gpu_id() const {
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id));
- return cuda_gpu_id.value();
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id));
+ return platform_gpu_id.value();
}
// The executor that provides control for the device; e.g., for CUDA this
@@ -173,14 +173,14 @@
int32 strength;
static const int kSameDeviceStrength;
static const int kStreamExecutorStrength;
- std::set<std::pair<CudaGpuId, CudaGpuId>> directed_links;
+ std::set<std::pair<PlatformGpuId, PlatformGpuId>> directed_links;
};
protected:
// Populates *maps with interconnect maps for all local direct access
// pathways between GPUs.
virtual Status GetInterconnectMaps(
- const std::vector<CudaGpuId>& visible_gpu_order,
+ const std::vector<PlatformGpuId>& visible_gpu_order,
se::Platform* gpu_manager, std::vector<InterconnectMap>* maps);
struct TfGpuIdHash {
@@ -212,16 +212,16 @@
Allocator* gpu_allocator,
Allocator* cpu_allocator) = 0;
- // Returns into 'ids' the list of valid CUDA GPU ids, in the order that
+ // Returns into 'ids' the list of valid platform GPU ids, in the order that
// they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc,
// based upon 'visible_gpu_order' which was generated by parsing
// GPUOptions::visible_device_list which is a comma-separated list of CUDA GPU
// ids.
- Status GetValidDeviceIds(const std::vector<CudaGpuId>& visible_gpu_order,
- std::vector<CudaGpuId>* ids);
+ Status GetValidDeviceIds(const std::vector<PlatformGpuId>& visible_gpu_order,
+ std::vector<PlatformGpuId>* ids);
- // visible_gpu_initialized_[cuda_gpu_id] is true if visible GPU cuda_gpu_id
- // has been initialized by the process.
+ // visible_gpu_initialized_[platform_gpu_id] is true if visible GPU
+ // platform_gpu_id has been initialized by the process.
std::unordered_map<int, bool> visible_gpu_initialized_;
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index daf59f0..3629409 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -30,18 +30,21 @@
namespace {
const char* kDeviceNamePrefix = "/job:localhost/replica:0/task:0";
-int64 GetTotalGPUMemory(CudaGpuId gpu_id) {
+int64 GetTotalGPUMemory(PlatformGpuId gpu_id) {
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(GPUMachineManager(), gpu_id)
+ .ValueOrDie();
int64 total_memory, available_memory;
CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));
return total_memory;
}
-Status GetComputeCapability(CudaGpuId gpu_id, int* cc_major, int* cc_minor) {
+Status GetComputeCapability(PlatformGpuId gpu_id, int* cc_major,
+ int* cc_minor) {
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(GPUMachineManager(), gpu_id)
+ .ValueOrDie();
if (!se->GetDeviceDescription().cuda_compute_capability(cc_major, cc_minor)) {
*cc_major = 0;
*cc_minor = 0;
@@ -223,7 +226,7 @@
// error.
TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) {
int cc_major, cc_minor;
- TF_ASSERT_OK(GetComputeCapability(CudaGpuId(0), &cc_major, &cc_minor));
+ TF_ASSERT_OK(GetComputeCapability(PlatformGpuId(0), &cc_major, &cc_minor));
// Exit early while running on Pascal or later GPUs.
if (cc_major >= 6) {
return;
@@ -244,10 +247,10 @@
// more memory than what is available on the device.
TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
static constexpr double kGpuMemoryFraction = 1.2;
- static constexpr CudaGpuId kCudaGpuId(0);
+ static constexpr PlatformGpuId kPlatformGpuId(0);
int cc_major, cc_minor;
- TF_ASSERT_OK(GetComputeCapability(kCudaGpuId, &cc_major, &cc_minor));
+ TF_ASSERT_OK(GetComputeCapability(kPlatformGpuId, &cc_major, &cc_minor));
// Exit early if running on pre-Pascal GPUs.
if (cc_major < 6) {
LOG(INFO)
@@ -262,7 +265,7 @@
ASSERT_EQ(1, devices.size());
int64 memory_limit = devices[0]->attributes().memory_limit();
- ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kCudaGpuId) *
+ ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kPlatformGpuId) *
kGpuMemoryFraction));
AllocatorAttributes allocator_attributes = AllocatorAttributes();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id.h b/tensorflow/core/common_runtime/gpu/gpu_id.h
index 2a6caea..f0d9022 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id.h
@@ -25,10 +25,10 @@
// physical machine, it can be filtered by CUDA environment variable
// CUDA_VISIBLE_DEVICES. Note that this id is not visible to Tensorflow, but
// result after filtering by CUDA_VISIBLE_DEVICES is visible to TF and is
-// called CUDA GPU id as below. See
+// called platform GPU id as below. See
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
// for more details.
-// - CUDA GPU id (also called *visible* GPU id in
+// - *platform* GPU id (also called *visible* GPU id in
// third_party/tensorflow/core/protobuf/config.proto): this is the id that is
// visible to Tensorflow after filtering by CUDA_VISIBLE_DEVICES, and is
// generated by the CUDA GPU driver. It starts from 0 and is used for CUDA API
@@ -39,14 +39,14 @@
// field of the device name "/device:GPU:<id>", and is also the identifier of
// a BaseGPUDevice. Note that the configuration allows us to create multiple
// BaseGPUDevice per GPU hardware in order to use multi CUDA streams on the
-// hardware, so the mapping between TF GPU id and CUDA GPU id is not a 1:1
+// hardware, so the mapping between TF GPU id and platform GPU id is not a 1:1
// mapping, see the example below.
//
// For example, assuming that in the machine we have GPU device with index 0, 1,
// 2 and 3 (physical GPU id). Setting "CUDA_VISIBLE_DEVICES=1,2,3" will create
-// the following mapping between CUDA GPU id and physical GPU id:
+// the following mapping between platform GPU id and physical GPU id:
//
-// CUDA GPU id -> physical GPU id
+// platform GPU id -> physical GPU id
// 0 -> 1
// 1 -> 2
// 2 -> 3
@@ -56,32 +56,32 @@
//
// Assuming we configure the Session to create one BaseGPUDevice per GPU
// hardware, then setting GPUOptions::visible_device_list to "2,0" will create
-// the following mappting between TF GPU id and CUDA GPU id:
+// the following mappting between TF GPU id and platform GPU id:
//
-// TF GPU id -> CUDA GPU ID
+// TF GPU id -> platform GPU ID
// 0 (i.e. /device:GPU:0) -> 2
// 1 (i.e. /device:GPU:1) -> 0
//
-// Note that CUDA GPU id 1 is filtered out by GPUOptions::visible_device_list,
-// so it won't be used by the TF process.
+// Note that platform GPU id 1 is filtered out by
+// GPUOptions::visible_device_list, so it won't be used by the TF process.
//
// On the other hand, if we configure it to create 2 BaseGPUDevice per GPU
// hardware, then setting GPUOptions::visible_device_list to "2,0" will create
-// the following mappting between TF GPU id and CUDA GPU id:
+// the following mappting between TF GPU id and platform GPU id:
//
-// TF GPU id -> CUDA GPU ID
+// TF GPU id -> platform GPU ID
// 0 (i.e. /device:GPU:0) -> 2
// 1 (i.e. /device:GPU:1) -> 2
// 2 (i.e. /device:GPU:2) -> 0
// 3 (i.e. /device:GPU:3) -> 0
//
-// We create strong-typed integer classes for both TF GPU id and CUDA GPU id to
-// minimize programming errors and improve code readability. Except for the
+// We create strong-typed integer classes for both TF GPU id and platform GPU id
+// to minimize programming errors and improve code readability. Except for the
// StreamExecutor interface (as we don't change its API), whenever we need a
-// TF GPU id (or CUDA GPU id) we should use TfGpuId (or CudaGpuId) instead of a
-// raw integer.
+// TF GPU id (or platform GPU id) we should use TfGpuId (or PlatformGpuId)
+// instead of a raw integer.
TF_LIB_GTL_DEFINE_INT_TYPE(TfGpuId, int32);
-TF_LIB_GTL_DEFINE_INT_TYPE(CudaGpuId, int32);
+TF_LIB_GTL_DEFINE_INT_TYPE(PlatformGpuId, int32);
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
index b5099dc..2b40730 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
@@ -26,26 +26,27 @@
namespace tensorflow {
namespace {
-// Manages the map between TfGpuId and CUDA GPU id.
-class TfToCudaGpuIdMap {
+// Manages the map between TfGpuId and platform GPU id.
+class TfToPlatformGpuIdMap {
public:
- static TfToCudaGpuIdMap* singleton() {
- static auto* id_map = new TfToCudaGpuIdMap;
+ static TfToPlatformGpuIdMap* singleton() {
+ static auto* id_map = new TfToPlatformGpuIdMap;
return id_map;
}
- Status Insert(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id) LOCKS_EXCLUDED(mu_) {
+ Status Insert(TfGpuId tf_gpu_id, PlatformGpuId platform_gpu_id)
+ LOCKS_EXCLUDED(mu_) {
std::pair<IdMapType::iterator, bool> result;
{
mutex_lock lock(mu_);
- result = id_map_.insert({tf_gpu_id.value(), cuda_gpu_id.value()});
+ result = id_map_.insert({tf_gpu_id.value(), platform_gpu_id.value()});
}
- if (!result.second && cuda_gpu_id.value() != result.first->second) {
+ if (!result.second && platform_gpu_id.value() != result.first->second) {
return errors::AlreadyExists(
"TensorFlow device (GPU:", tf_gpu_id.value(),
") is being mapped to "
"multiple CUDA devices (",
- cuda_gpu_id.value(), " now, and ", result.first->second,
+ platform_gpu_id.value(), " now, and ", result.first->second,
" previously), which is not supported. "
"This may be the result of providing different GPU configurations "
"(ConfigProto.gpu_options, for example different visible_device_list)"
@@ -56,17 +57,17 @@
return Status::OK();
}
- bool Find(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) const
+ bool Find(TfGpuId tf_gpu_id, PlatformGpuId* platform_gpu_id) const
LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
auto result = id_map_.find(tf_gpu_id.value());
if (result == id_map_.end()) return false;
- *cuda_gpu_id = result->second;
+ *platform_gpu_id = result->second;
return true;
}
private:
- TfToCudaGpuIdMap() = default;
+ TfToPlatformGpuIdMap() = default;
void TestOnlyReset() LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
@@ -78,17 +79,18 @@
IdMapType id_map_ GUARDED_BY(mu_);
friend class ::tensorflow::GpuIdManager;
- TF_DISALLOW_COPY_AND_ASSIGN(TfToCudaGpuIdMap);
+ TF_DISALLOW_COPY_AND_ASSIGN(TfToPlatformGpuIdMap);
};
} // namespace
-Status GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id,
- CudaGpuId cuda_gpu_id) {
- return TfToCudaGpuIdMap::singleton()->Insert(tf_gpu_id, cuda_gpu_id);
+Status GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId tf_gpu_id,
+ PlatformGpuId platform_gpu_id) {
+ return TfToPlatformGpuIdMap::singleton()->Insert(tf_gpu_id, platform_gpu_id);
}
-Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) {
- if (TfToCudaGpuIdMap::singleton()->Find(tf_gpu_id, cuda_gpu_id)) {
+Status GpuIdManager::TfToPlatformGpuId(TfGpuId tf_gpu_id,
+ PlatformGpuId* platform_gpu_id) {
+ if (TfToPlatformGpuIdMap::singleton()->Find(tf_gpu_id, platform_gpu_id)) {
return Status::OK();
}
return errors::NotFound("TensorFlow device GPU:", tf_gpu_id.value(),
@@ -96,7 +98,7 @@
}
void GpuIdManager::TestOnlyReset() {
- TfToCudaGpuIdMap::singleton()->TestOnlyReset();
+ TfToPlatformGpuIdMap::singleton()->TestOnlyReset();
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
index 491d92c..62df431 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
@@ -21,15 +21,17 @@
namespace tensorflow {
-// Class that maintains a map from TfGpuId to CudaGpuId, and manages the
+// Class that maintains a map from TfGpuId to PlatformGpuId, and manages the
// translation between them.
class GpuIdManager {
public:
- // Adds a mapping from tf_gpu_id to cuda_gpu_id.
- static Status InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id);
+ // Adds a mapping from tf_gpu_id to platform_gpu_id.
+ static Status InsertTfPlatformGpuIdPair(TfGpuId tf_gpu_id,
+ PlatformGpuId platform_gpu_id);
- // Gets the cuda_gpu_id associated with tf_gpu_id. Returns OK if found.
- static Status TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id);
+ // Gets the platform_gpu_id associated with tf_gpu_id. Returns OK if found.
+ static Status TfToPlatformGpuId(TfGpuId tf_gpu_id,
+ PlatformGpuId* platform_gpu_id);
// Clears the map. Used in unit tests only.
static void TestOnlyReset();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
index a663ec7..8bf3c6a 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
@@ -22,38 +22,38 @@
namespace tensorflow {
namespace {
-CudaGpuId TfToCudaGpuId(TfGpuId tf) {
- CudaGpuId cuda;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf, &cuda));
- return cuda;
+PlatformGpuId TfToPlatformGpuId(TfGpuId tf) {
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf, &platform_gpu_id));
+ return platform_gpu_id;
}
TEST(GpuIdManagerTest, Basics) {
TfGpuId key_0(0);
- CudaGpuId value_0(0);
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0));
- EXPECT_EQ(value_0, TfToCudaGpuId(key_0));
+ PlatformGpuId value_0(0);
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_0, value_0));
+ EXPECT_EQ(value_0, TfToPlatformGpuId(key_0));
// Multiple calls to map the same value is ok.
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0));
- EXPECT_EQ(value_0, TfToCudaGpuId(key_0));
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_0, value_0));
+ EXPECT_EQ(value_0, TfToPlatformGpuId(key_0));
// Map a different TfGpuId to a different value.
TfGpuId key_1(3);
- CudaGpuId value_1(2);
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_1, value_1));
- EXPECT_EQ(value_1, TfToCudaGpuId(key_1));
+ PlatformGpuId value_1(2);
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_1, value_1));
+ EXPECT_EQ(value_1, TfToPlatformGpuId(key_1));
// Mapping a different TfGpuId to the same value is ok.
TfGpuId key_2(10);
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_1));
- EXPECT_EQ(value_1, TfToCudaGpuId(key_2));
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_2, value_1));
+ EXPECT_EQ(value_1, TfToPlatformGpuId(key_2));
// Mapping the same TfGpuId to a different value.
- ASSERT_FALSE(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_0).ok());
+ ASSERT_FALSE(GpuIdManager::InsertTfPlatformGpuIdPair(key_2, value_0).ok());
// Getting a nonexistent mapping.
- ASSERT_FALSE(GpuIdManager::TfToCudaGpuId(TfGpuId(100), &value_0).ok());
+ ASSERT_FALSE(GpuIdManager::TfToPlatformGpuId(TfGpuId(100), &value_0).ok());
}
} // namespace
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
index b9c66b3..b1f10fb 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
@@ -24,34 +24,37 @@
namespace tensorflow {
-// Utility methods for translation between Tensorflow GPU ids and CUDA GPU ids.
+// Utility methods for translation between Tensorflow GPU ids and platform GPU
+// ids.
class GpuIdUtil {
public:
// Convenient methods for getting the associated executor given a TfGpuId or
- // CudaGpuId.
- static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId(
- se::Platform* gpu_manager, CudaGpuId cuda_gpu_id) {
- return gpu_manager->ExecutorForDevice(cuda_gpu_id.value());
+ // PlatformGpuId.
+ static se::port::StatusOr<se::StreamExecutor*> ExecutorForPlatformGpuId(
+ se::Platform* gpu_manager, PlatformGpuId platform_gpu_id) {
+ return gpu_manager->ExecutorForDevice(platform_gpu_id.value());
}
- static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId(
- CudaGpuId cuda_gpu_id) {
- return ExecutorForCudaGpuId(GPUMachineManager(), cuda_gpu_id);
+ static se::port::StatusOr<se::StreamExecutor*> ExecutorForPlatformGpuId(
+ PlatformGpuId platform_gpu_id) {
+ return ExecutorForPlatformGpuId(GPUMachineManager(), platform_gpu_id);
}
static se::port::StatusOr<se::StreamExecutor*> ExecutorForTfGpuId(
TfGpuId tf_gpu_id) {
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
- return ExecutorForCudaGpuId(cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+ return ExecutorForPlatformGpuId(platform_gpu_id);
}
- // Verify that the cuda_gpu_id associated with a TfGpuId is legitimate.
+ // Verify that the platform_gpu_id associated with a TfGpuId is legitimate.
static void CheckValidTfGpuId(TfGpuId tf_gpu_id) {
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
const int visible_device_count = GPUMachineManager()->VisibleDeviceCount();
- CHECK_LT(cuda_gpu_id.value(), visible_device_count)
- << "cuda_gpu_id is outside discovered device range."
- << " TF GPU id: " << tf_gpu_id << " CUDA GPU id: " << cuda_gpu_id
+ CHECK_LT(platform_gpu_id.value(), visible_device_count)
+ << "platform_gpu_id is outside discovered device range."
+ << " TF GPU id: " << tf_gpu_id
+ << " platform GPU id: " << platform_gpu_id
<< " visible device count: " << visible_device_count;
}
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
index 9ec740f..3e95374 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
@@ -107,14 +107,15 @@
return nullptr;
}
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
int bus_id = BusIdForGPU(tf_gpu_id);
while (bus_id >= gpu_visitors_.size()) {
gpu_visitors_.push_back({});
}
GPUMemAllocator* sub_allocator = new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), cuda_gpu_id,
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id,
(options.per_process_gpu_memory_fraction() > 1.0 ||
options.experimental().use_unified_memory()),
gpu_visitors_[bus_id], {});
@@ -125,20 +126,21 @@
// If true, checks for memory overwrites by writing
// distinctive patterns on both ends of allocated memory.
if (useCudaMemoryGuardAllocator()) {
- gpu_allocator = new GPUDebugAllocator(gpu_allocator, cuda_gpu_id);
- gpu_allocator = new GPUNanResetAllocator(gpu_allocator, cuda_gpu_id);
+ gpu_allocator = new GPUDebugAllocator(gpu_allocator, platform_gpu_id);
+ gpu_allocator = new GPUNanResetAllocator(gpu_allocator, platform_gpu_id);
} else if (useCudaMallocAllocator()) {
// If true, passes all allocation requests through to cudaMalloc
// useful for doing memory debugging with tools like cuda-memcheck
// **WARNING** probably will not work in a multi-gpu scenario
- gpu_allocator = new GPUcudaMallocAllocator(gpu_allocator, cuda_gpu_id);
+ gpu_allocator =
+ new GPUcudaMallocAllocator(gpu_allocator, platform_gpu_id);
}
Allocator* recording_allocator = nullptr;
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
ProcessState::MemDesc md;
md.loc = ProcessState::MemDesc::GPU;
- md.dev_index = cuda_gpu_id.value();
+ md.dev_index = platform_gpu_id.value();
md.gpu_registered = false;
md.nic_registered = true;
recording_allocator = new internal::RecordingAllocator(
diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
index f9f3644..6af4ca4 100644
--- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
+++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
@@ -50,8 +50,8 @@
}
for (Node* n : matches) {
AttrSlice n_attrs = n->attrs();
- auto base_make_node = [n, g, &n_attrs](const string& op,
- const string& name) {
+ auto base_make_node = [n, &n_attrs](const string& op,
+ const string& name) {
NodeBuilder node_builder(name, op);
node_builder.Device(n->requested_device());
string colo;
@@ -60,7 +60,7 @@
}
return node_builder;
};
- auto make_node = [n, g, &n_attrs, &base_make_node](string op) {
+ auto make_node = [n, g, &base_make_node](string op) {
return base_make_node(
op, g->NewName(strings::StrCat(n->name(), "/Internal")));
};
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
index 9d59264..c00789a 100644
--- a/tensorflow/core/common_runtime/renamed_device.h
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -58,6 +58,15 @@
return underlying_->GetAllocator(attr);
}
+ Allocator* GetScopedAllocator(AllocatorAttributes attr,
+ int64 step_id) override {
+ return underlying_->GetScopedAllocator(attr, step_id);
+ }
+
+ ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
+ return underlying_->GetScopedAllocatorMgr();
+ }
+
const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
return underlying_->eigen_cpu_device();
}
diff --git a/tensorflow/core/common_runtime/session_ref.cc b/tensorflow/core/common_runtime/session_ref.cc
deleted file mode 100644
index b931ef4..0000000
--- a/tensorflow/core/common_runtime/session_ref.cc
+++ /dev/null
@@ -1,170 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/common_runtime/session_ref.h"
-
-#include <utility>
-
-namespace tensorflow {
-
-namespace {
-
-// Scope helper to track active calls and manage session lifetime.
-struct RunCounter {
- std::shared_ptr<Session> session;
- uint64* value;
- mutex* m;
- condition_variable* cv;
-
- explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
- condition_variable* cv)
- : session(std::move(s)), value(v), m(m), cv(cv) {
- mutex_lock l(*m);
- ++*value;
- }
-
- ~RunCounter() {
- mutex_lock l(*m);
- if (--*value == 0) {
- cv->notify_all();
- }
- }
-};
-
-} // namespace
-
-Status SessionRef::CheckNotClosed() {
- mutex_lock l(run_lock_);
- if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
- return ::tensorflow::Status::OK();
-}
-
-Status SessionRef::Run(const RunOptions& run_options,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs,
- RunMetadata* run_metadata) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Run(run_options, inputs, output_tensor_names,
- target_node_names, outputs, run_metadata);
-}
-
-Status SessionRef::Create(const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Create(graph);
-}
-
-Status SessionRef::Create(const RunOptions& run_options,
- const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Create(run_options, graph);
-}
-
-Status SessionRef::Extend(const RunOptions& run_options,
- const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Extend(run_options, graph);
-}
-
-Status SessionRef::Extend(const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Extend(graph);
-}
-
-Status SessionRef::Close(const RunOptions& run_options) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(run_lock_);
- Status status = session_->Close(run_options);
- session_.reset();
- while (run_count_ > 0) {
- run_finished_.wait(l);
- }
- return status;
-}
-
-Status SessionRef::Close() {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(run_lock_);
- Status status = session_->Close();
- session_.reset();
- while (run_count_ > 0) {
- run_finished_.wait(l);
- }
- return status;
-}
-
-Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Run(inputs, output_tensor_names, target_node_names,
- outputs);
-}
-
-Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->ListDevices(response);
-}
-
-Status SessionRef::PRunSetup(const std::vector<string>& input_names,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes,
- string* handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->PRunSetup(input_names, output_names, target_nodes, handle);
-}
-
-Status SessionRef::PRun(const string& handle,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_names,
- std::vector<Tensor>* outputs) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->PRun(handle, inputs, output_names, outputs);
-}
-
-Status SessionRef::MakeCallable(const CallableOptions& callable_options,
- CallableHandle* out_handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->MakeCallable(callable_options, out_handle);
-}
-
-Status SessionRef::RunCallable(CallableHandle handle,
- const std::vector<Tensor>& feed_tensors,
- std::vector<Tensor>* fetch_tensors,
- RunMetadata* run_metadata) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->RunCallable(handle, feed_tensors, fetch_tensors,
- run_metadata);
-}
-
-Status SessionRef::ReleaseCallable(CallableHandle handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->ReleaseCallable(handle);
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index 836cb8e..a70ab93 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -27,6 +27,7 @@
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
@@ -40,46 +41,24 @@
};
} // namespace
-NodeExecStatsWrapper::NodeExecStatsWrapper(const string& node_name)
- : NodeExecStatsWrapper(new NodeExecStats) {
- stats_->set_node_name(node_name);
-}
-NodeExecStatsWrapper::NodeExecStatsWrapper(NodeExecStats* stats)
- : stats_(stats) {}
-
-void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* v) {
- DCHECK(v);
- NodeOutput* no = stats_->add_output();
- no->set_slot(slot);
- v->FillDescription(no->mutable_tensor_description());
+NodeExecStatsWrapper::NodeExecStatsWrapper(
+ const Node* node, StepStatsCollector* step_stats_collector)
+ : NodeExecStatsWrapper(MakeUnique<NodeExecStats>(), node,
+ step_stats_collector) {
+ stats_->set_node_name(node->name());
}
-void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
- for (const auto& allocator_pair : ctx->wrapped_allocators()) {
- AddAllocation(allocator_pair.first, allocator_pair.second);
- }
- auto* ms = stats_->mutable_memory_stats();
- ms->set_temp_memory_size(ctx->temp_memory_allocated());
- for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
- ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
- }
- ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
-}
+NodeExecStatsWrapper::NodeExecStatsWrapper(
+ std::unique_ptr<NodeExecStats> stats, const Node* node,
+ StepStatsCollector* step_stats_collector)
+ : stats_(std::move(stats)),
+ node_(node),
+ step_stats_collector_(step_stats_collector) {}
-void NodeExecStatsWrapper::SetReferencedTensors(
- const TensorReferenceVector& tensors) {
- // be careful not to increment the reference count on any tensor
- // while recording the information
- for (size_t i = 0; i < tensors.size(); ++i) {
- AllocationDescription* description = stats_->add_referenced_tensor();
- tensors.at(i).FillDescription(description);
- }
-}
-
-// TODO(tucker): merge with the DetailText function in session.cc
-// in a common location.
-bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) {
- bool is_transfer_node = false;
+void NodeExecStatsWrapper::Done(const string& device) {
+ // TODO(tucker): merge with the DetailText function in session.cc in a common
+ // location.
+ DCHECK(node_);
string memory;
for (auto& all : stats_->memory()) {
int64 tot = all.total_bytes();
@@ -96,31 +75,96 @@
}
}
}
- const AttrSlice attrs = node->attrs();
+ const AttrSlice attrs = node_->attrs();
string text;
- if (IsSend(node)) {
+ if (IsSend(node_)) {
string tensor_name;
TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
string recv_device;
TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+ text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(),
"(", tensor_name, " @", recv_device);
- is_transfer_node = true;
- } else if (IsRecv(node)) {
+ } else if (IsRecv(node_)) {
string tensor_name;
TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
string send_device;
TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+ text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(),
"(", tensor_name, " @", send_device);
- is_transfer_node = true;
} else {
text =
- strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
- str_util::Join(node->requested_inputs(), ", "), ")");
+ strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(",
+ str_util::Join(node_->requested_inputs(), ", "), ")");
}
stats_->set_timeline_label(text);
- return is_transfer_node;
+ step_stats_collector_->Save(device, this);
+}
+
+void NodeExecStatsWrapper::RecordExecutorStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
+ stats_->set_all_start_nanos(now_nanos);
+}
+
+void NodeExecStatsWrapper::RecordComputeStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::RecordComputeEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::RecordExecutorEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::SetScheduled(int64 nanos) {
+ stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
+ stats_->set_scheduled_nanos(nanos);
+}
+
+void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
+ for (const auto& allocator_pair : ctx->wrapped_allocators()) {
+ AddAllocation(allocator_pair.first, allocator_pair.second);
+ }
+ auto* ms = stats_->mutable_memory_stats();
+ ms->set_temp_memory_size(ctx->temp_memory_allocated());
+ for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
+ ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
+ }
+ ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+}
+
+void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* tensor) {
+ DCHECK(tensor);
+ NodeOutput* node_output = stats_->add_output();
+ node_output->set_slot(slot);
+ tensor->FillDescription(node_output->mutable_tensor_description());
+}
+
+void NodeExecStatsWrapper::SetReferencedTensors(
+ const TensorReferenceVector& tensors) {
+ // be careful not to increment the reference count on any tensor
+ // while recording the information
+ for (size_t i = 0; i < tensors.size(); ++i) {
+ AllocationDescription* description = stats_->add_referenced_tensor();
+ tensors.at(i).FillDescription(description);
+ }
}
void NodeExecStatsWrapper::AddAllocation(
@@ -150,8 +194,8 @@
allocations_.clear();
}
-StepStatsCollector::StepStatsCollector(StepStats* ss)
- : finalized_(false), step_stats_(ss) {}
+StepStatsCollector::StepStatsCollector(StepStats* step_stats)
+ : finalized_(false), step_stats_(step_stats) {}
static int ExtractGpuWithStreamAll(string device_name) {
// Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp,
@@ -338,30 +382,42 @@
}
}
-void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
- Save(device, new NodeExecStatsWrapper(nt));
+void StepStatsCollector::Save(const string& device,
+ NodeExecStats* node_stats_pb) {
+ Save(device,
+ new NodeExecStatsWrapper(std::unique_ptr<NodeExecStats>(node_stats_pb),
+ nullptr, this));
}
void StepStatsCollector::Save(const string& device,
- NodeExecStatsWrapper* stats) {
- if (!stats) return;
- VLOG(1) << "Save dev " << device << " nt " << stats->stats();
+ NodeExecStatsWrapper* node_stats) {
+ if (!node_stats) return;
+ VLOG(1) << "Save dev " << device << " node stats " << node_stats->stats();
{
mutex_lock l(mu_);
if (finalized_) {
LOG(WARNING) << "stats saved after finalize will not be collected.";
}
- if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) {
+ if (!step_stats_ || collected_nodes_ >= kMaxCollectedNodes) {
VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
- delete stats;
+ delete node_stats;
return;
}
- auto& dss = dev_stats_[device];
- dss.push_back(std::unique_ptr<NodeExecStatsWrapper>(stats));
- collectedNodes++;
+ auto& device_stats = dev_stats_[device];
+ device_stats.push_back(std::unique_ptr<NodeExecStatsWrapper>(node_stats));
+ collected_nodes_++;
}
}
+NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats(
+ const Node* node) {
+ // Only collect statistics for non-transfer nodes.
+ if (IsSend(node) || IsRecv(node)) {
+ return nullptr;
+ }
+ return new NodeExecStatsWrapper(node, this);
+}
+
string StepStatsCollector::ReportAllocsOnResourceExhausted(const string& err) {
mutex_lock l(mu_);
if (err.find("OOM") == err.npos) {
@@ -446,12 +502,12 @@
FinalizeInternal();
}
-void StepStatsCollector::FinalizeAndSwap(StepStats* ss) {
+void StepStatsCollector::FinalizeAndSwap(StepStats* step_stats) {
mutex_lock l(mu_);
CHECK(step_stats_);
FinalizeInternal();
- ss->Swap(step_stats_);
- collectedNodes = 0;
+ step_stats->Swap(step_stats_);
+ collected_nodes_ = 0;
}
void StepStatsCollector::FinalizeInternal() {
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 7206fbf..4365b11 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -36,81 +36,78 @@
class NodeExecStats;
class OpKernelContext;
class StepStats;
+class StepStatsCollector;
class Tensor;
class TrackingAllocator;
-// Wraps NodeExecStats and adds allocation to it.
-class NodeExecStatsWrapper {
+// Statistics collection interface for individual node execution.
+//
+// See `NodeExecStatsWrapper` for a concrete implementation of this interface
+// that interfaces with the `Session` layer.
+class NodeExecStatsInterface {
public:
- NodeExecStatsWrapper(const string& node_name);
- // Owns 'stats'.
- NodeExecStatsWrapper(NodeExecStats* stats);
+ virtual ~NodeExecStatsInterface() {}
- // Destructor calls Finalize() to release the TrackingAllocators.
- ~NodeExecStatsWrapper() { Finalize(); }
-
- // Records the absolute time in nanoseconds at which this node became
- // runnable (i.e. was scheduled for execution).
- void SetScheduled(int64 nanos) {
- stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
- stats_->set_scheduled_nanos(nanos);
- }
+ // Called when the statistics collection for the node has finished. Once this
+ // method is called, the caller should not make assumptions about the validity
+ // of this object.
+ virtual void Done(const string& device) = 0;
// Called immediately after this node starts being processed by the executor.
- void RecordExecutorStarted() {
- int64 now_nanos = Env::Default()->NowNanos();
- stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
- stats_->set_all_start_nanos(now_nanos);
- }
+ virtual void RecordExecutorStarted() = 0;
// Called immediately before this node's `Compute()` or `ComputeAsync()`
// method is called.
- void RecordComputeStarted() {
- int64 now_nanos = Env::Default()->NowNanos();
- DCHECK_NE(stats_->all_start_micros(), 0);
- DCHECK_NE(stats_->all_start_nanos(), 0);
- stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- stats_->all_start_micros());
- stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
- }
+ virtual void RecordComputeStarted() = 0;
// Called immediately after this node's `Compute()` method returned (or, for
// asynchronous operations, the callback passed to its `ComputeAsync()` method
// was called).
- void RecordComputeEnded() {
- int64 now_nanos = Env::Default()->NowNanos();
- DCHECK_NE(stats_->all_start_micros(), 0);
- DCHECK_NE(stats_->all_start_nanos(), 0);
- stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- stats_->all_start_micros());
- stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
- }
+ virtual void RecordComputeEnded() = 0;
// Called immediately after this executor finishes processing this node.
- void RecordExecutorEnded() {
- int64 now_nanos = Env::Default()->NowNanos();
- DCHECK_NE(stats_->all_start_micros(), 0);
- DCHECK_NE(stats_->all_start_nanos(), 0);
- stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- stats_->all_start_micros());
- stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
- }
-
- // Records information about the tensor produced by this node at the given
- // output slot.
- void SetOutput(int slot, const Tensor* v);
+ virtual void RecordExecutorEnded() = 0;
// Records information about the memory allocated during the execution of this
// node.
- void SetMemory(OpKernelContext* ctx);
+ virtual void SetMemory(OpKernelContext* ctx) = 0;
+
+ // Records information about the tensor produced by this node at the given
+ // output slot.
+ virtual void SetOutput(int slot, const Tensor* tensor) = 0;
// Records information about the tensors that were accessed during the
// execution of this node.
- void SetReferencedTensors(const TensorReferenceVector& tensors);
+ virtual void SetReferencedTensors(const TensorReferenceVector& tensors) = 0;
- // Sets the timeline_label field of the wrapped NodeExecStats, using data
- // from *node. Returns true iff the node is a transfer node.
- bool SetTimelineLabel(const Node* node);
+ // Records the absolute time in nanoseconds at which this node became
+ // runnable (i.e. was scheduled for execution).
+ virtual void SetScheduled(int64 nanos) = 0;
+};
+
+// Wraps NodeExecStats and adds allocation to it.
+class NodeExecStatsWrapper : public NodeExecStatsInterface {
+ public:
+ // Does not take ownership of `node` or `step_stats_collector`.
+ NodeExecStatsWrapper(const Node* node,
+ StepStatsCollector* step_stats_collector);
+
+  // Takes ownership of `stats` but not `node` or `step_stats_collector`.
+ NodeExecStatsWrapper(std::unique_ptr<NodeExecStats> stats, const Node* node,
+ StepStatsCollector* step_stats_collector);
+
+ // Destructor calls Finalize() to release the TrackingAllocators.
+ ~NodeExecStatsWrapper() { Finalize(); }
+
+ void Done(const string& device) override;
+ void RecordExecutorStarted() override;
+ void RecordComputeStarted() override;
+ void RecordComputeEnded() override;
+ void RecordExecutorEnded() override;
+ void SetMemory(OpKernelContext* ctx) override;
+ void SetOutput(int slot, const Tensor* tensor) override;
+ void SetReferencedTensors(const TensorReferenceVector& tensors) override;
+ void SetScheduled(int64 nanos) override;
private:
friend class StepStatsCollector;
@@ -128,9 +125,11 @@
gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2>
allocations_;
std::unique_ptr<NodeExecStats> stats_;
+ const Node* const node_; // Not owned.
+ StepStatsCollector* const step_stats_collector_; // Not owned.
};
-// Statistics collection interface for individual node execution.
+// Statistics collection interface for step execution.
//
// See `StepStatsCollector` for a concrete implementation of this interface
// that interfaces with the `Session` layer.
@@ -138,8 +137,9 @@
public:
virtual ~StepStatsCollectorInterface() {}
- // Saves `stats` to the collector.
- virtual void Save(const string& device, NodeExecStatsWrapper* stats) = 0;
+ // Creates an instance of `NodeExecStatsInterface` that should be used for
+ // collecting statistics about individual node execution.
+ virtual NodeExecStatsInterface* CreateNodeExecStats(const Node* node) = 0;
// Generates a string reporting the currently used memory based
// on ResourceExhausted OOM `err` message.
@@ -154,8 +154,8 @@
// Each DeviceStats object holds multiple NodeExecStats.
class StepStatsCollector : public StepStatsCollectorInterface {
public:
- // Does not take ownership of `ss`.
- explicit StepStatsCollector(StepStats* ss);
+ // Does not take ownership of `step_stats`.
+ explicit StepStatsCollector(StepStats* step_stats);
// BuildCostModel builds or updates a CostModel managed by cost_model_manager,
// using the currently collected DeviceStats associated with the devices in
@@ -164,11 +164,12 @@
CostModelManager* cost_model_manager,
const std::unordered_map<string, const Graph*>& device_map);
- // Save saves nt to the DeviceStats object associated with device.
+ // Saves node statistics to the DeviceStats object associated with device.
// Should be called before Finalize.
- void Save(const string& device, NodeExecStats* nt);
- void Save(const string& device, NodeExecStatsWrapper* stats) override;
+ void Save(const string& device, NodeExecStats* node_stats_pb);
+ void Save(const string& device, NodeExecStatsWrapper* node_stats);
+ NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override;
string ReportAllocsOnResourceExhausted(const string& err) override;
// The following 2 Finalize methods populate the StepStats passed
@@ -176,20 +177,22 @@
// User shouldn't call Save() methods after Finalize.
void Finalize();
-  // swaps the content of StepStats* from constructor with 'ss'.
+  // Swaps the content of StepStats* from constructor with 'step_stats'.
- void FinalizeAndSwap(StepStats* ss);
+ void FinalizeAndSwap(StepStats* step_stats);
private:
+  // TODO(suharshs): Make this configurable if it's not possible to find a value
+ // that works for all cases.
+ static const uint64 kMaxCollectedNodes = 1 << 20;
+
+ typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeStatsVector;
+
void FinalizeInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_);
- typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeExecStatsVec;
- // TODO(suharshs): Make this configurable if its not possible to find a value
- // that works for all cases.
- const uint64 kMaxCollectedNodes = 1 << 20;
mutex mu_;
bool finalized_ GUARDED_BY(mu_);
- std::unordered_map<string, NodeExecStatsVec> dev_stats_ GUARDED_BY(mu_);
+ std::unordered_map<string, NodeStatsVector> dev_stats_ GUARDED_BY(mu_);
StepStats* step_stats_ GUARDED_BY(mu_);
- uint64 collectedNodes GUARDED_BY(mu_) = 0;
+ uint64 collected_nodes_ GUARDED_BY(mu_) = 0;
};
} // namespace tensorflow
diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h
index ec93b9a..016d1a9 100644
--- a/tensorflow/core/example/feature_util.h
+++ b/tensorflow/core/example/feature_util.h
@@ -103,6 +103,7 @@
#include <iterator>
#include <type_traits>
+#include "absl/base/macros.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -113,10 +114,10 @@
namespace internal {
-// DEPRECATED: Use GetFeature instead.
// TODO(gorban): Update all clients in a followup CL.
// Returns a reference to a feature corresponding to the name.
// Note: it will create a new Feature if it is missing in the example.
+ABSL_DEPRECATED("Use GetFeature instead.")
Feature& ExampleFeature(const string& name, Example* example);
// Specializations of RepeatedFieldTrait define a type of RepeatedField
@@ -314,9 +315,9 @@
return HasFeature<FeatureType...>(key, GetFeatures(example));
}
-// DEPRECATED: use HasFeature instead.
// TODO(gorban): update all clients in a followup CL.
template <typename... FeatureType>
+ABSL_DEPRECATED("Use HasFeature instead.")
bool ExampleHasFeature(const string& key, const Example& example) {
return HasFeature<FeatureType...>(key, example);
}
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc
index 1258e40..af59500 100644
--- a/tensorflow/core/framework/cancellation.cc
+++ b/tensorflow/core/framework/cancellation.cc
@@ -89,6 +89,16 @@
}
}
+bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
+ mutex_lock lock(mu_);
+ if (is_cancelled_ || is_cancelling_) {
+ return false;
+ } else {
+ callbacks_.erase(token);
+ return true;
+ }
+}
+
CancellationManager::~CancellationManager() {
if (!callbacks_.empty()) {
StartCancel();
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
index acdaaf6..7a5d942 100644
--- a/tensorflow/core/framework/cancellation.h
+++ b/tensorflow/core/framework/cancellation.h
@@ -122,6 +122,15 @@
// cancellation manager.
bool DeregisterCallback(CancellationToken token);
+ // Deregister the callback that, when registered, was associated
+ // with the given cancellation token. Returns true iff the callback
+ // was deregistered and will not be invoked; otherwise returns false
+ // immediately, with no guarantee that the callback has completed.
+ //
+ // This method is guaranteed to return true if StartCancel has not been
+ // called.
+ bool TryDeregisterCallback(CancellationToken token);
+
private:
bool is_cancelling_;
std::atomic_bool is_cancelled_;
diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc
index e3f1824..bf7593b 100644
--- a/tensorflow/core/framework/cancellation_test.cc
+++ b/tensorflow/core/framework/cancellation_test.cc
@@ -115,4 +115,56 @@
delete cm;
}
+TEST(Cancellation, TryDeregisterWithoutCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_TRUE(deregistered);
+ delete manager;
+ EXPECT_FALSE(is_cancelled);
+}
+
+TEST(Cancellation, TryDeregisterAfterCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ manager->StartCancel();
+ EXPECT_TRUE(is_cancelled);
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+ delete manager;
+}
+
+TEST(Cancellation, TryDeregisterDuringCancel) {
+ Notification cancel_started, finish_callback, cancel_complete;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(token, [&]() {
+ cancel_started.Notify();
+ finish_callback.WaitForNotification();
+ });
+ EXPECT_TRUE(registered);
+
+ thread::ThreadPool w(Env::Default(), "test", 1);
+ w.Schedule([&]() {
+ manager->StartCancel();
+ cancel_complete.Notify();
+ });
+ cancel_started.WaitForNotification();
+
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+
+ finish_callback.Notify();
+ cancel_complete.WaitForNotification();
+ delete manager;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 91b1e61..697e060 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -529,25 +529,11 @@
std::unique_ptr<IteratorBase>* iterator) const {
*iterator = MakeIteratorInternal(prefix);
if (ctx->model()) {
- // The prefix might contain an index. We need to strip it to make it
- // possible for the model to successfully identify the output node.
- string sanitized_prefix = prefix;
- if (str_util::EndsWith(prefix, "]")) {
- sanitized_prefix = prefix.substr(0, prefix.rfind('['));
- }
- std::shared_ptr<model::Node> node =
- ctx->model()->AddNode((*iterator)->prefix(), sanitized_prefix);
- std::vector<string> tokens =
- str_util::Split((*iterator)->prefix(), ':', str_util::SkipEmpty());
- node->set_name(tokens[tokens.size() - 1]);
+ ctx->model()->AddNode((*iterator)->prefix(), prefix);
std::shared_ptr<model::Model> model = ctx->model();
const string& prefix = (*iterator)->prefix();
- (*iterator)->AddCleanupFunction([model, node, prefix]() {
- if (node->output()) {
- node->output()->remove_input(node);
- }
- model->RemoveNode(prefix);
- });
+ (*iterator)->AddCleanupFunction(
+ [model, prefix]() { model->RemoveNode(prefix); });
}
return (*iterator)->Initialize(ctx);
}
@@ -629,23 +615,10 @@
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) final {
tracing::ScopedActivity activity(params_.prefix);
- Status s;
- if (ctx->model()) {
- std::shared_ptr<model::Node> node =
- ctx->model()->LookupNode(params_.prefix);
- if (node->output()) {
- node->output()->stop_work();
- }
- node->start_work();
- s = GetNextInternal(ctx, out_tensors, end_of_sequence);
- node->stop_work();
- node->add_element();
- if (node->output()) {
- node->output()->start_work();
- }
- } else {
- s = GetNextInternal(ctx, out_tensors, end_of_sequence);
- }
+ RecordStart(ctx, true /* stop_output */);
+ Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ if (s.ok() && !*end_of_sequence) RecordElement(ctx);
+ RecordStop(ctx, true /* start_output */);
if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
s = errors::Internal(
"Iterator \"", params_.prefix,
@@ -677,52 +650,46 @@
void AddConstantParameter(IteratorContext* ctx, const string& name,
int64 value) {
if (ctx->model()) {
- std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
- if (node) {
- node->add_constant_param(name, value);
- }
+ ctx->model()->AddConstantParameter(prefix(), name, value);
}
}
// When performance modeling is enabled, this method adds a tunable parameter
// to the model node corresponding to this iterator.
//
- // The `set_fn` function should set the tunable parameter to the value of
- // its input argument. The function should be thread-safe; in particular, the
- // state it updates should be protected by a lock as the function can be
- // invoked asynchronously. It is guaranteed that this function will not be
- // invoked after the iterator is deleted because the model node that owns
- // the function is deleted when the iterator is deleted.
+ // The performance modeling logic may use `value` to set the value of the
+ // tunable parameter at any point during the lifetime of this iterator. When
+ // it does, it notifies `cond_var`.
void AddTunableParameter(IteratorContext* ctx, const string& name,
- int64 value, int64 min, int64 max,
- std::function<void(int64)>&& set_fn) {
+ std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var) {
if (ctx->model()) {
- std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
- if (node) {
- node->add_tunable_param(name, value, min, max, std::move(set_fn));
- }
+ ctx->model()->AddTunableParameter(prefix(), name, value, min, max,
+ cond_var);
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // this iterator has produced an element.
+ void RecordElement(IteratorContext* ctx) {
+ if (ctx->model()) {
+ ctx->model()->RecordElement(prefix());
}
}
// When performance modeling is enabled, this method records the fact that
// a thread of this iterator has started work.
- void StartWork(IteratorContext* ctx) {
+ void RecordStart(IteratorContext* ctx, bool stop_output = false) {
if (ctx->model()) {
- std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
- if (node) {
- node->start_work();
- }
+ ctx->model()->RecordStart(prefix(), stop_output);
}
}
// When performance modeling is enabled, this method records the fact that
// a thread of this iterator has stopped work.
- void StopWork(IteratorContext* ctx) {
+ void RecordStop(IteratorContext* ctx, bool start_output = false) {
if (ctx->model()) {
- std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
- if (node) {
- node->stop_work();
- }
+ ctx->model()->RecordStop(prefix(), start_output);
}
}
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 53ac639..446c31b 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -20,6 +20,7 @@
#include <string>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -176,9 +177,9 @@
return nullptr;
}
- // DEPRECATED: Use `this->GetAllocator()` or `this->GetScopedAllocator()`.
// This method is provided for backwards compatibility, and will be removed
// in a future release.
+ ABSL_DEPRECATED("Use `this->GetAllocator()` or `this->GetScopedAllocator()`.")
Allocator* GetStepAllocator(AllocatorAttributes attr, ResourceMgr*) {
return GetAllocator(attr);
}
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index 112298c..b0330ec 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,16 +17,14 @@
#include <memory>
-#include "tensorflow/core/lib/gtl/map_util.h"
-
namespace tensorflow {
namespace data {
namespace model {
// TODO(jsimsa): Use `Node` subclassing instead of types and node statements.
-void Node::CollectTunables(
+void Model::Node::CollectTunables(
std::vector<std::shared_ptr<Node::Tunable>>* tunables) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (auto input : inputs_) {
input->CollectTunables(tunables);
}
@@ -45,14 +43,14 @@
}
}
-int64 Node::GetParameterValue(const string& name) {
+int64 Model::Node::GetParameterValue(const string& name) {
if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) {
return (*tunable_param)->value;
}
return constant_params_[name];
}
-int64 Node::ProcessingTimeLocked() {
+int64 Model::Node::ProcessingTimeLocked() {
switch (type_) {
case Type::BATCH:
case Type::MAP_AND_BATCH:
@@ -101,7 +99,7 @@
}
}
-int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
+int64 Model::Node::OutputTimeLocked(std::vector<int64>* input_times) {
switch (type_) {
case Type::BATCH:
case Type::PADDED_BATCH: {
@@ -251,15 +249,34 @@
}
}
-std::shared_ptr<Node> Model::AddNode(const string& name,
- const string& output_name) {
- mutex_lock l(mu_);
+void Model::AddConstantParameter(const string& node_name,
+ const string& parameter_name, int64 value) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, node_name);
+ if (node) {
+ (*node)->add_constant_param(parameter_name, value);
+ }
+}
+
+void Model::AddNode(const string& name, const string& output_name) {
+  // The name is the `::`-joined iterator sequence: the full name keys the
+  // lookup table and its last element (or the whole name) names the node.
+  std::vector<string> tokens =
+      str_util::Split(name, ':', str_util::SkipEmpty());
+  if (tokens.empty()) tokens.push_back(name);  // Keeps tokens.back() valid.
+  // The output name might contain an index. We need to strip it to make it
+  // possible for the model to successfully identify the output node.
+  string sanitized_output_name = output_name;
+  if (str_util::EndsWith(output_name, "]")) {
+    sanitized_output_name = output_name.substr(0, output_name.rfind('['));
+  }
std::shared_ptr<Node> output;
- auto it = lookup_table_.find(output_name);
+ mutex_lock l(mu_);
+ auto it = lookup_table_.find(sanitized_output_name);
if (it != lookup_table_.end()) {
output = it->second;
}
- std::shared_ptr<Node> node(new Node(id_counter_++, output));
+ std::shared_ptr<Node> node(new Node(id_counter_++, tokens.back(), output));
if (!output_) {
output_ = node;
}
@@ -267,88 +284,125 @@
output->add_input(node);
}
lookup_table_.insert(std::make_pair(name, node));
- return node;
}
-std::shared_ptr<Node> Model::LookupNode(const string& name) {
+void Model::AddProcessingTime(const string& name, int64 delta) {
tf_shared_lock l(mu_);
- std::shared_ptr<Node> result;
- auto it = lookup_table_.find(name);
- if (it != lookup_table_.end()) {
- result = it->second;
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->add_processing_time(delta);
}
- return result;
+}
+
+void Model::AddTunableParameter(const string& node_name,
+                                const string& parameter_name,
+                                std::atomic<int64>* value, int64 min, int64 max,
+                                condition_variable* cond_var) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindPtrOrNull(lookup_table_, node_name);
+  DCHECK(node) << "Unknown model node: " << node_name;
+  if (node) node->add_tunable_param(parameter_name, value, min, max, cond_var);
+}
// The optimization algorithm starts by setting all tunable parallelism
-// parameters to 1. It then repeatedly identifies the parameter that whose
-// increase in parallelism decreases the output time the most. This process is
-// repeated until all parameters reach their maximum values or the
-// projected output time is less than or equal to the processing time needed to
-// produce an element divided by CPU budget.
+// parameters to 1. It then repeatedly identifies the parameter whose increase
+// in parallelism decreases the output time the most. This process is repeated
+// until all parameters reach their maximum values or the projected output time
+// is less than or equal to the processing time needed to produce an element
+// divided by CPU budget.
void Model::Optimize(int64 cpu_budget) {
- mutex_lock l(optimization_mu_);
- std::vector<std::shared_ptr<Node::Tunable>> tunables;
- {
- mutex_lock l2(mu_);
- const int64 processing_time = ProcessingTime();
- tunables = CollectTunables();
- for (auto tunable : tunables) {
- tunable->value = 1;
- }
- while (true) {
- const int64 output_time = OutputTime();
- bool all_tunables = true;
- for (auto& tunable : tunables) {
- if (tunable->value < tunable->max) {
- all_tunables = false;
- break;
- }
- }
- if (output_time < processing_time / cpu_budget || all_tunables) {
- break;
- }
- int64 best_delta = -1;
- Node::Tunable* best_tunable = nullptr;
- for (auto& tunable : tunables) {
- if (tunable->value == tunable->max) {
- continue;
- }
- tunable->value++;
- int64 delta = output_time - OutputTime();
- if (delta > best_delta) {
- best_delta = delta;
- best_tunable = tunable.get();
- }
- tunable->value--;
- }
- if (!best_tunable) {
- // NOTE: This can happen because we are performing the optimization
- // while the model data is changing. If this becomes an issue, we should
- // look into performing the optimization using a model snapshot.
- break;
- }
- best_tunable->value++;
- }
+ tf_shared_lock lock(mu_);
+ std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
+ const int64 processing_time = ProcessingTime();
+ tunables = CollectTunables();
+ for (auto tunable : tunables) {
+ tunable->value = 1;
}
- // The `set_fn` functions should be invoked without holding a lock to avoid a
- // potential deadlock.
+ while (true) {
+ const int64 output_time = OutputTime();
+ bool all_tunables = true;
+ for (auto& tunable : tunables) {
+ if (tunable->value < tunable->max) {
+ all_tunables = false;
+ break;
+ }
+ }
+ if (output_time < processing_time / cpu_budget || all_tunables) {
+ break;
+ }
+ int64 best_delta = -1;
+ Model::Node::Tunable* best_tunable = nullptr;
+ for (auto& tunable : tunables) {
+ if (tunable->value == tunable->max) {
+ continue;
+ }
+ tunable->value++;
+ int64 delta = output_time - OutputTime();
+ if (delta > best_delta) {
+ best_delta = delta;
+ best_tunable = tunable.get();
+ }
+ tunable->value--;
+ }
+ if (!best_tunable) {
+ // NOTE: This can happen because we are performing the optimization
+ // while the model data is changing. If this becomes an issue, we should
+ // look into performing the optimization using a model snapshot.
+ break;
+ }
+ best_tunable->value++;
+ }
+ VLOG(2) << "Number of knobs: " << tunables.size();
for (auto& tunable : tunables) {
- tunable->set_fn(tunable->value);
+ VLOG(2) << "Setting tunable parameter: " << tunable->value;
+ tunable->value_ptr->store(tunable->value);
+ if (tunable->cond_var) {
+ tunable->cond_var->notify_all();
+ }
}
}
-void Model::RemoveNode(const string& prefix) {
- // Nodes are not allowed to be removed when optimization is in progress to
- // prevent the optimization from trying to access an iterator that was
- // concurrently deleted.
- mutex_lock l(optimization_mu_);
- mutex_lock l2(mu_);
- lookup_table_.erase(prefix);
+void Model::RecordElement(const string& name) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->record_element();
+ }
}
-std::vector<std::shared_ptr<Node::Tunable>> Model::CollectTunables() {
- std::vector<std::shared_ptr<Node::Tunable>> tunables;
+void Model::RecordStart(const string& name, bool stop_output) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ if (stop_output && (*node)->output()) {
+ (*node)->output()->record_stop();
+ }
+ (*node)->record_start();
+ }
+}
+
+void Model::RecordStop(const string& name, bool start_output) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->record_stop();
+ if (start_output && (*node)->output()) {
+ (*node)->output()->record_start();
+ }
+ }
+}
+
+void Model::RemoveNode(const string& name) {
+ mutex_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node && (*node)->output()) {
+ (*node)->output()->remove_input(*node);
+ }
+ lookup_table_.erase(name);
+}
+
+std::vector<std::shared_ptr<Model::Node::Tunable>> Model::CollectTunables() {
+ std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
output_->CollectTunables(&tunables);
return tunables;
}
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index f88ec06..26402f5 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -24,6 +24,7 @@
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
@@ -32,294 +33,6 @@
namespace data {
namespace model {
-class Model;
-class Node;
-
-// Abstract representation of a TensorFlow input pipeline node. It collects
-// information about inputs to this node, processing time spent executing the
-// node logic, number of elements produced by the node, various other
-// information (e.g. batch size or execution parallelism).
-//
-// Developers of tf.data transformations are not expected to interact with this
-// class directly. Boiler plate code for creating the abstract representation of
-// the input pipeline and collecting common information has been added to the
-// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
-//
-// In addition, `DatasetBaseIterator` provides wrappers that can be used for
-// transformation-specific information collection. The `SetMetadata` wrapper can
-// be used to pass arbitrary metadata to the modeling framework, while the
-// `StartWork` and `StopWork` wrappers should be used to correctly account for
-// processing time of multi-threaded transformation that yield the CPU; such
-// transformations should invoke `StartWork()` when a transformation thread
-// starts executing (e.g. when created or woken up) and `StopWork()` when a
-// transformation thread stops executing (e.g. when returning or waiting).
-//
-// TODO(jsimsa): Create an API to capture the abstract semantics of each
-// tf.data transformation and replace switch-case blocks with inheritance.
-class Node {
- public:
- Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {}
-
- // Adds a constant parameter.
- void add_constant_param(const string& name, int64 value) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- constant_params_[name] = value;
- }
-
- // Records that the node produced an element.
- void add_element() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- num_elements_++;
- }
-
- // Adds an input.
- void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- inputs_.push_back(node);
- }
-
- // Increments the aggregate processing time by the given delta.
- void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- processing_time_ += delta;
- }
-
- // Adds a tunable parameter.
- void add_tunable_param(const string& name, int64 value, int64 min, int64 max,
- std::function<void(int64)>&& set_fn)
- LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- tunable_params_[name] =
- std::make_shared<Tunable>(value, min, max, std::move(set_fn));
- }
-
- // Returns the unique node ID.
- int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
-
- // Returns the node inputs.
- std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return inputs_;
- }
-
- // Returns the node name.
- const string& name() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return name_;
- }
-
- // Returns the number of elements produced by the node.
- int64 num_elements() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return num_elements_;
- }
-
- // Returns the node output.
- std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return output_;
- }
-
- // Removes an input.
- void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- inputs_.remove(input);
- }
-
- // Sets the node name.
- void set_name(const string& name) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- name_ = name;
- type_ = TypeFromName(name);
- }
-
- // Set the node output.
- void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- output_ = output;
- }
-
- // Records that a node thread has started work.
- void start_work() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
- }
-
- // Records that a node thread has stopped work.
- void stop_work() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- auto iter = work_start_.find(std::this_thread::get_id());
- CHECK(work_start_.end() != iter)
- << "Encountered a stop event that was not preceded by a start event.";
- processing_time_ += Env::Default()->NowNanos() - iter->second;
- work_start_.erase(iter);
- }
-
- private:
- // Represents a tunable parameter.
- struct Tunable {
- Tunable(int64 value, int64 min, int64 max,
- std::function<void(int64)> set_fn)
- : value(value), min(min), max(max), set_fn(std::move(set_fn)) {}
-
- int64 value;
- int64 min;
- int64 max;
- std::function<void(int64)> set_fn;
- };
-
- enum class Type {
- BATCH = 0,
- CACHE,
- CONCATENATE,
- FILTER,
- FLAT_MAP,
- INTERLEAVE,
- MAP,
- MAP_AND_BATCH,
- PADDED_BATCH,
- PARALLEL_INTERLEAVE,
- PARALLEL_INTERLEAVE_V2,
- PARALLEL_MAP,
- PREFETCH,
- REPEAT,
- SHUFFLE,
- SKIP,
- TAKE,
- ZIP,
- UNKNOWN,
- };
-
- // Collects tunable parameters in the subtree rooted in this node.
- void CollectTunables(std::vector<std::shared_ptr<Node::Tunable>>* tunables)
- LOCKS_EXCLUDED(mu_);
-
- // Gets a value of the given parameter (tunable or constant).
- int64 GetParameterValue(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- // Returns the per-element processing time spent in this node.
- int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- return NanosPerElementLocked();
- }
-
- int64 NanosPerElementLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (num_elements_ == 0) {
- return 0;
- }
- return (int64)((double)processing_time_ / (double)num_elements_);
- }
-
- // Returns the per-element output time for this node.
- int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- return OutputTimeLocked(input_times);
- }
-
- int64 OutputTimeLocked(std::vector<int64>* input_times)
- EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- int64 OutputTimeForInputs(std::vector<int64>* input_times)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- int64 sum = 0;
- for (auto input : inputs_) {
- sum += input->OutputTime(input_times);
- }
- return sum;
- }
-
- // Returns the per-element processing time spent in the subtree rooted in this
- // node.
- int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- return ProcessingTimeLocked();
- }
-
- int64 ProcessingTimeLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- // Returns the per-element processing time spent in the inputs of this node.
- int64 ProcessingTimeForInputs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- int64 sum = 0;
- for (auto input : inputs_) {
- sum += input->ProcessingTimeLocked();
- }
- return sum;
- }
-
- Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (name_ == "Batch") {
- return Type::BATCH;
- }
- if (str_util::EndsWith(name_, "Cache")) {
- return Type::CACHE;
- }
- if (name_ == "Concatenate") {
- return Type::CONCATENATE;
- }
- if (name_ == "Filter") {
- return Type::FILTER;
- }
- if (name_ == "FlatMap") {
- return Type::FLAT_MAP;
- }
- if (name_ == "Interleave") {
- return Type::INTERLEAVE;
- }
- if (name_ == "Map") {
- return Type::MAP;
- }
- if (name_ == "MapAndBatch") {
- return Type::MAP_AND_BATCH;
- }
- if (name_ == "PaddedBatch") {
- return Type::PADDED_BATCH;
- }
- if (name_ == "ParallelInterleave") {
- return Type::PARALLEL_INTERLEAVE;
- }
- if (name_ == "ParallelInterleaveV2") {
- return Type::PARALLEL_INTERLEAVE_V2;
- }
- if (name_ == "ParallelMap") {
- return Type::PARALLEL_MAP;
- }
- if (name_ == "Prefetch") {
- return Type::PREFETCH;
- }
- if (str_util::EndsWith(name_, "Repeat")) {
- return Type::REPEAT;
- }
- if (name_ == "Shuffle") {
- return Type::SHUFFLE;
- }
- if (str_util::EndsWith(name_, "Skip")) {
- return Type::SKIP;
- }
- if (str_util::EndsWith(name_, "Take")) {
- return Type::TAKE;
- }
- if (name_ == "Zip") {
- return Type::ZIP;
- }
- return Type::UNKNOWN;
- }
-
- mutex mu_;
- const int64 id_;
- Type type_ GUARDED_BY(mu_);
- string name_ GUARDED_BY(mu_);
- int64 processing_time_ GUARDED_BY(mu_) = 0;
- int64 num_elements_ GUARDED_BY(mu_) = 0;
- std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
- std::map<string, int64> constant_params_ GUARDED_BY(mu_);
- // Tunables are shared with the model during optimization.
- std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_);
- std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
- std::shared_ptr<Node> output_ GUARDED_BY(mu_);
-
- friend class Model;
-};
-
// Abstract representation of a TensorFlow input pipeline that can be used
// for collecting runtime information and optimizing performance. It collects
// runtime information about execution of the input pipeline that is used to
@@ -334,39 +47,351 @@
public:
Model() = default;
- // Returns the model output node.
- std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return output_;
- }
-
- // Adds a node with the given name and given output (identified by name).
- std::shared_ptr<Node> AddNode(const string& name, const string& output_name)
+ // Adds a constant parameter for the given node.
+ void AddConstantParameter(const string& node_name,
+ const string& parameter_name, int64 value)
LOCKS_EXCLUDED(mu_);
- // Looks up the node using the given name.
- std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_);
+ // Adds a node with the given name and given output (identified by name).
+ void AddNode(const string& name, const string& output_name)
+ LOCKS_EXCLUDED(mu_);
+
+ // Increments the processing time for the given node..
+ void AddProcessingTime(const string& name, int64 delta) LOCKS_EXCLUDED(mu_);
+
+ // Adds a tunable parameter for the given node.
+ void AddTunableParameter(const string& node_name,
+ const string& parameter_name,
+ std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var) LOCKS_EXCLUDED(mu_);
// Runs optimization.
void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
- // Removes the node identified by the given name.
- void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_);
+ // Records that a node has produced an element.
+ void RecordElement(const string& name) LOCKS_EXCLUDED(mu_);
+
+ // Records that the given node has started work. If `stop_output` is set, it
+ // also records that the output of the given node has stopped work.
+ void RecordStart(const string& name, bool stop_output) LOCKS_EXCLUDED(mu_);
+
+ // Records that the given node has stopped work. If `stop_output` is set, it
+ // also records that the output of the given node has started work.
+ void RecordStop(const string& name, bool start_output) LOCKS_EXCLUDED(mu_);
+
+ // Removes the given node.
+ void RemoveNode(const string& name) LOCKS_EXCLUDED(mu_);
private:
+ // Abstract representation of a TensorFlow input pipeline node. It collects
+ // information about inputs to this node, processing time spent executing the
+ // node logic, number of elements produced by the node, various other
+ // information (e.g. batch size or execution parallelism).
+ //
+ // Developers of tf.data transformations are not expected to interact with
+ // this class directly. Boiler plate code for creating the abstract
+ // representation of the input pipeline and collecting common information has
+ // been added to the implementation of `DatasetBase` and `DatasetBaseIterator`
+ // respectively.
+ //
+ // In addition, `DatasetBaseIterator` provides wrappers that can be used for
+ // transformation-specific information collection. The `SetMetadata` wrapper
+ // can be used to pass arbitrary metadata to the modeling framework, while the
+ // `StartWork` and `StopWork` wrappers should be used to correctly account for
+ // processing time of multi-threaded transformation that yield the CPU; such
+ // transformations should invoke `StartWork()` when a transformation thread
+ // starts executing (e.g. when created or woken up) and `StopWork()` when a
+ // transformation thread stops executing (e.g. when returning or waiting).
+ //
+ // TODO(jsimsa): Create an API to capture the abstract semantics of each
+ // tf.data transformation and replace switch-case blocks with inheritance.
+ class Node {
+ public:
+ // Represents a tunable parameter.
+ struct Tunable {
+ Tunable(std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var)
+ : value(*value),
+ min(min),
+ max(max),
+ value_ptr(value),
+ cond_var(cond_var) {}
+
+ // Identifies the model value of the parameter. This can be different from
+ // the actual value (e.g. during optimization search).
+ int64 value;
+
+ // Identifies the minimum value of the parameter.
+ int64 min;
+
+ // Identifies the maximum value of the parameter.
+ int64 max;
+
+ // Points to the actual value of the parameter. Not owned.
+ std::atomic<int64>* value_ptr;
+
+ // If non-null, this condition variable is notified when the model updates
+ // the actual value of the parameter (via `value_ptr`). Not owned.
+ condition_variable* cond_var;
+ };
+
+ Node(int64 id, const string& name, std::shared_ptr<Node> output)
+ : id_(id), name_(name), type_(TypeFromName(name)), output_(output) {}
+
+ // Adds a constant parameter.
+ void add_constant_param(const string& name, int64 value)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ constant_params_[name] = value;
+ }
+
+ // Adds an input.
+ void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.push_back(node);
+ }
+
+ // Increments the aggregate processing time by the given delta.
+ void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
+ }
+
+ // Adds a tunable parameter.
+ void add_tunable_param(const string& name, std::atomic<int64>* value,
+ int64 min, int64 max, condition_variable* cond_var)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ tunable_params_[name] =
+ std::make_shared<Tunable>(value, min, max, cond_var);
+ }
+
+ // Returns the unique node ID.
+ int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
+
+ // Returns the node inputs.
+ std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return inputs_;
+ }
+
+ // Returns the node name.
+ const string& name() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return name_;
+ }
+
+ // Returns the number of elements produced by the node.
+ int64 num_elements() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return num_elements_;
+ }
+
+ // Returns the node output.
+ std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return output_;
+ }
+
+ // Records that the node produced an element.
+ void record_element() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ num_elements_++;
+ }
+
+ // Records that a node thread has started executing.
+ void record_start() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
+ }
+
+ // Records that a node thread has stopped executing.
+ void record_stop() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ std::thread::id tid = std::this_thread::get_id();
+ auto start_time = gtl::FindOrNull(work_start_, tid);
+ DCHECK(start_time)
+ << "Encountered a stop event that was not preceded by a start event.";
+ if (start_time) {
+ processing_time_ += Env::Default()->NowNanos() - *start_time;
+ work_start_.erase(tid);
+ }
+ }
+
+ // Removes an input.
+ void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.remove(input);
+ }
+
+ // Set the node output.
+ void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ output_ = output;
+ }
+
+ // Collects tunable parameters in the subtree rooted in this node.
+ void CollectTunables(std::vector<std::shared_ptr<Tunable>>* tunables)
+ LOCKS_EXCLUDED(mu_);
+
+ // Returns the per-element output time for this node.
+ int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return OutputTimeLocked(input_times);
+ }
+
+ // Returns the per-element processing time spent in the subtree rooted in
+ // this node.
+ int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return ProcessingTimeLocked();
+ }
+
+ private:
+ enum class Type {
+ BATCH = 0,
+ CACHE,
+ CONCATENATE,
+ FILTER,
+ FLAT_MAP,
+ INTERLEAVE,
+ MAP,
+ MAP_AND_BATCH,
+ PADDED_BATCH,
+ PARALLEL_INTERLEAVE,
+ PARALLEL_INTERLEAVE_V2,
+ PARALLEL_MAP,
+ PREFETCH,
+ REPEAT,
+ SHUFFLE,
+ SKIP,
+ TAKE,
+ ZIP,
+ UNKNOWN,
+ };
+
+ // Gets a value of the given parameter (tunable or constant).
+ int64 GetParameterValue(const string& name) SHARED_LOCKS_REQUIRED(mu_);
+
+ // Returns the per-element processing time spent in this node.
+ int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return NanosPerElementLocked();
+ }
+
+ int64 NanosPerElementLocked() SHARED_LOCKS_REQUIRED(mu_) {
+ if (num_elements_ == 0) {
+ return 0;
+ }
+ return (int64)((double)processing_time_ / (double)num_elements_);
+ }
+
+ int64 OutputTimeLocked(std::vector<int64>* input_times)
+ SHARED_LOCKS_REQUIRED(mu_);
+
+ int64 OutputTimeForInputs(std::vector<int64>* input_times)
+ SHARED_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->OutputTime(input_times);
+ }
+ return sum;
+ }
+
+ int64 ProcessingTimeLocked() SHARED_LOCKS_REQUIRED(mu_);
+
+ // Returns the per-element processing time spent in the inputs of this node.
+ int64 ProcessingTimeForInputs() SHARED_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->ProcessingTime();
+ }
+ return sum;
+ }
+
+ Type TypeFromName(const string& name) SHARED_LOCKS_REQUIRED(mu_) {
+ if (name_ == "Batch") {
+ return Type::BATCH;
+ }
+ if (str_util::EndsWith(name_, "Cache")) {
+ return Type::CACHE;
+ }
+ if (name_ == "Concatenate") {
+ return Type::CONCATENATE;
+ }
+ if (name_ == "Filter") {
+ return Type::FILTER;
+ }
+ if (name_ == "FlatMap") {
+ return Type::FLAT_MAP;
+ }
+ if (name_ == "Interleave") {
+ return Type::INTERLEAVE;
+ }
+ if (name_ == "Map") {
+ return Type::MAP;
+ }
+ if (name_ == "MapAndBatch") {
+ return Type::MAP_AND_BATCH;
+ }
+ if (name_ == "PaddedBatch") {
+ return Type::PADDED_BATCH;
+ }
+ if (name_ == "ParallelInterleave") {
+ return Type::PARALLEL_INTERLEAVE;
+ }
+ if (name_ == "ParallelInterleaveV2") {
+ return Type::PARALLEL_INTERLEAVE_V2;
+ }
+ if (name_ == "ParallelMap") {
+ return Type::PARALLEL_MAP;
+ }
+ if (name_ == "Prefetch") {
+ return Type::PREFETCH;
+ }
+ if (str_util::EndsWith(name_, "Repeat")) {
+ return Type::REPEAT;
+ }
+ if (name_ == "Shuffle") {
+ return Type::SHUFFLE;
+ }
+ if (str_util::EndsWith(name_, "Skip")) {
+ return Type::SKIP;
+ }
+ if (str_util::EndsWith(name_, "Take")) {
+ return Type::TAKE;
+ }
+ if (name_ == "Zip") {
+ return Type::ZIP;
+ }
+ return Type::UNKNOWN;
+ }
+
+ mutex mu_;
+ const int64 id_;
+ const string name_;
+ const Type type_;
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+ int64 num_elements_ GUARDED_BY(mu_) = 0;
+ std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
+ std::map<string, int64> constant_params_ GUARDED_BY(mu_);
+ // Tunables are shared with the model during optimization.
+ std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_);
+ std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+ };
+
std::vector<std::shared_ptr<Node::Tunable>> CollectTunables()
- EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ SHARED_LOCKS_REQUIRED(mu_);
- int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ int64 OutputTime() SHARED_LOCKS_REQUIRED(mu_);
- int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ int64 ProcessingTime() SHARED_LOCKS_REQUIRED(mu_);
- // Used for coordination between different input pipeline threads.
+ // Used for coordination between different input pipeline threads. Exclusive
+ // access is required only when adding or removing nodes. Concurrent access to
+ // existing nodes is protected by a node mutex.
mutex mu_;
- // Used for preventing iterator deletion when optimization is in progress
- // because the optimization may try to update the values of tunable
- // parameters.
- mutex optimization_mu_ ACQUIRED_BEFORE(mu_);
int64 id_counter_ GUARDED_BY(mu_) = 1;
std::shared_ptr<Node> output_ GUARDED_BY(mu_);
std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 42ec315..43ac1d0 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -372,6 +372,14 @@
node_def.name());
}
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs) {
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
+ }
+ return Status::OK();
+}
+
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int output_port, DataType* output_type) {
DataTypeVector output_types;
@@ -397,9 +405,7 @@
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs) {
- for (const auto& arg : op_def.input_arg()) {
- TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
- }
+ TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
return OutputTypesForNode(node_def, op_def, outputs);
}
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 7528d3d..187bfa2 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -249,6 +249,10 @@
// REQUIRES: ValidateOpDef(op_def).ok()
Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int input_port, DataType* input_type);
+// Computes the input types for a specific node.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs);
// Computes the output type for a specific node output.
// REQUIRES: ValidateOpDef(op_def).ok()
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 74cc594..d9d4370 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -370,6 +370,48 @@
"Illegal op input name 'a:00");
}
+TEST(InputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_FLOAT);
+ EXPECT_EQ(types[1], DT_INT32);
+
+ DataType type;
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_FLOAT);
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_INT32);
+ EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
+TEST(OutputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_STRING);
+ EXPECT_EQ(types[1], DT_BOOL);
+
+ DataType type;
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_STRING);
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_BOOL);
+ EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
TEST(NameRangesForNodeTest, Simple) {
const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
.Input("a: float")
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 516afa5..eb9c79f 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -948,9 +948,69 @@
}
}
+// Appends the spacing between elements for a given dim onto a result string
+void PrintDimSpacing(int dim_index, int num_dims, string* result) {
+ if (dim_index == num_dims - 1) {
+ strings::StrAppend(result, " ");
+ return;
+ }
+ for (int j = 0; j < num_dims - dim_index - 1; j++) {
+ strings::StrAppend(result, "\n");
+ }
+ for (int j = 0; j <= dim_index; j++) {
+ strings::StrAppend(result, " ");
+ }
+}
+
+// Print from left dim to right dim recursively.
+template <typename T>
+void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
+ int64 num_elts_at_ends, int num_dims, const T* data,
+ int64 data_index, string* result) {
+ // We have recursed beyond all the dimensions into a single element
+ // of the tensor.
+ if (dim_index == num_dims) {
+ strings::StrAppend(result, PrintOneElement(data[data_index]));
+ return;
+ }
+
+ strings::StrAppend(result, "[");
+ int64 element_count = shape[dim_index];
+ int64 start_of_end =
+ std::max(num_elts_at_ends, element_count - num_elts_at_ends);
+
+ // Loop every element of one dim.
+ int64 elements_per_iter = 1;
+ for (int i = dim_index + 1; i < num_dims; i++) {
+ elements_per_iter *= shape[i];
+ }
+ for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
+ if (i > 0) {
+ PrintDimSpacing(dim_index, num_dims, result);
+ }
+
+ // As for each element, print the sub-dim.
+ PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+ data_index + elements_per_iter * i, result);
+ }
+ if (element_count > 2 * num_elts_at_ends) {
+ PrintDimSpacing(dim_index, num_dims, result);
+ strings::StrAppend(result, "...");
+ }
+ for (int64 i = start_of_end; i < element_count; i++) {
+ // As for each element, print the sub-dim.
+ PrintDimSpacing(dim_index, num_dims, result);
+ PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+ data_index + elements_per_iter * i, result);
+ }
+
+ strings::StrAppend(result, "]");
+}
+
template <typename T>
string SummarizeArray(int64 limit, int64 num_elts,
- const TensorShape& tensor_shape, const char* data) {
+ const TensorShape& tensor_shape, const char* data,
+ const bool print_v2) {
string ret;
const T* array = reinterpret_cast<const T*>(data);
@@ -963,17 +1023,26 @@
if (num_elts > limit) strings::StrAppend(&ret, "...");
return ret;
}
- int64 data_index = 0;
- const int shape_size = tensor_shape.dims();
- PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
+ if (print_v2) {
+ const int num_dims = tensor_shape.dims();
+ PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
+ } else {
+ int64 data_index = 0;
+ const int shape_size = tensor_shape.dims();
+ PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
- if (num_elts > limit) strings::StrAppend(&ret, "...");
+ if (num_elts > limit) strings::StrAppend(&ret, "...");
+ }
+
return ret;
}
} // namespace
-string Tensor::SummarizeValue(int64 max_entries) const {
+string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
const int64 num_elts = NumElements();
+ if (max_entries < 0) {
+ max_entries = num_elts;
+ }
size_t limit = std::min(max_entries, num_elts);
if ((limit > 0) && (buf_ == nullptr)) {
return strings::StrCat("uninitialized Tensor of ", num_elts,
@@ -982,50 +1051,54 @@
const char* data = limit > 0 ? tensor_data().data() : nullptr;
switch (dtype()) {
case DT_HALF:
- return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data);
+ return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
+ print_v2);
break;
case DT_FLOAT:
- return SummarizeArray<float>(limit, num_elts, shape_, data);
+ return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
break;
case DT_DOUBLE:
- return SummarizeArray<double>(limit, num_elts, shape_, data);
+ return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT32:
- return SummarizeArray<uint32>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT32:
- return SummarizeArray<int32>(limit, num_elts, shape_, data);
+ return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT8:
case DT_QUINT8:
- return SummarizeArray<uint8>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT16:
case DT_QUINT16:
- return SummarizeArray<uint16>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT16:
case DT_QINT16:
- return SummarizeArray<int16>(limit, num_elts, shape_, data);
+ return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT8:
case DT_QINT8:
- return SummarizeArray<int8>(limit, num_elts, shape_, data);
+ return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT64:
- return SummarizeArray<uint64>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT64:
- return SummarizeArray<int64>(limit, num_elts, shape_, data);
+ return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
break;
case DT_BOOL:
// TODO(tucker): Is it better to emit "True False..."? This
// will emit "1 0..." which is more compact.
- return SummarizeArray<bool>(limit, num_elts, shape_, data);
+ return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
break;
default: {
// All irregular cases
string ret;
+ if (print_v2) {
+ strings::StrAppend(&ret, "[");
+ }
// TODO(irving): Don't call flat every time around this
// loop.
for (size_t i = 0; i < limit; ++i) {
@@ -1045,6 +1118,9 @@
}
}
if (max_entries < num_elts) strings::StrAppend(&ret, "...");
+ if (print_v2) {
+ strings::StrAppend(&ret, "]");
+ }
return ret;
}
}
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 696fd27..e412329 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -154,7 +154,7 @@
/// Returns the estimated memory usage of this tensor.
size_t TotalBytes() const;
- // Returns the size of sallocated memory for this tensor.
+ // Returns the size of allocated memory for this tensor.
size_t AllocatedBytes() const;
/// Returns true iff this tensor is aligned.
@@ -430,7 +430,7 @@
int64 begin) const;
/// Render the first `max_entries` values in `*this` into a string.
- string SummarizeValue(int64 max_entries) const;
+ string SummarizeValue(int64 max_entries, bool print_v2 = false) const;
/// A human-readable summary of the tensor suitable for debugging.
string DebugString() const;
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 9a78cdc..fc05c86 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -1295,6 +1295,63 @@
EXPECT_EQ("one two three four five one...", x.SummarizeValue(6));
}
+TEST(SummarizeValue, INT32_PRINT_V2) {
+ Tensor x = MkTensor<int>(DT_INT32, TensorShape({5}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+ EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]",
+ x.SummarizeValue(16, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({0}), {});
+ EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, INT32Dims_PRINT_V2) {
+ Tensor x = MkTensor<int>(DT_INT32, TensorShape({3, 4}),
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ EXPECT_EQ("[[1 ... 4]\n ...\n [9 ... 12]]", x.SummarizeValue(1, true));
+ EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+ x.SummarizeValue(10, true));
+ EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+ x.SummarizeValue(-1, true));
+}
+
+TEST(SummarizeValue, FLOAT_PRINT_V2) {
+ Tensor x = MkTensor<float>(DT_FLOAT, TensorShape({5}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+ EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]",
+ x.SummarizeValue(16, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({0}), {});
+ EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, BOOL_PRINT_V2) {
+ Tensor x = MkTensor<bool>(DT_BOOL, TensorShape({5}), {false, true, true});
+ EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[0 1 ... 0 1]", x.SummarizeValue(2, true));
+}
+
+TEST(SummarizeValue, STRING_PRINT_V2) {
+ Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}),
+ {"one", "two", "three", "four", "five"});
+ EXPECT_EQ("[one two three four five]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[one two three four five]", x.SummarizeValue(-1, true));
+ x = MkTensor<string>(DT_STRING, TensorShape({5, 1, 5}),
+ {"one", "two", "three", "four", "five"});
+ EXPECT_EQ("[one two three four five one...]", x.SummarizeValue(6, true));
+}
+
void BM_CreateAndDestroy(int iters) {
TensorShape shape({10, 20});
while (--iters) {
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index bd0284d..b00196f 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -32,7 +32,7 @@
namespace graph {
// Converts "g" into its corresponding GraphDef "def".
-// DEPRECATED: call g->ToGraphDef(def) instead.
+ABSL_DEPRECATED("Call g->ToGraphDef(def) instead.")
void ToGraphDef(Graph* g, GraphDef* def);
// A few helpers to construct a graph.
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index b97603c..e4f6bf7 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -93,13 +93,13 @@
strings::StrCat("Not able to parse GPU device name: ", dev.name()));
}
TfGpuId tf_gpu_id(parsed.id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (!s.ok()) {
return errors::Unavailable("Unknown TF GPU device with id ",
tf_gpu_id.value(), ": ", s.ToString());
}
- attr = GetLocalGPUInfo(cuda_gpu_id);
+ attr = GetLocalGPUInfo(platform_gpu_id);
} else if (dev.device_type().find("XLA") == string::npos) {
// Filter out the fake XLA devices to avoid double counting the actual
// hardware resources that are available.
diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc
index a751972..567e7c0 100644
--- a/tensorflow/core/grappler/clusters/utils.cc
+++ b/tensorflow/core/grappler/clusters/utils.cc
@@ -70,13 +70,14 @@
return device;
}
-DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id) {
+DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id) {
DeviceProperties device;
device.set_type("GPU");
#if GOOGLE_CUDA
cudaDeviceProp properties;
- cudaError_t error = cudaGetDeviceProperties(&properties, cuda_gpu_id.value());
+ cudaError_t error =
+ cudaGetDeviceProperties(&properties, platform_gpu_id.value());
if (error != cudaSuccess) {
device.set_type("UNKNOWN");
LOG(ERROR) << "Failed to get device properties, error code: " << error;
@@ -122,15 +123,15 @@
} else if (device.type == "GPU") {
if (device.has_id) {
TfGpuId tf_gpu_id(device.id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (!s.ok()) {
LOG(ERROR) << s;
return unknown;
}
- return GetLocalGPUInfo(cuda_gpu_id);
+ return GetLocalGPUInfo(platform_gpu_id);
} else {
- return GetLocalGPUInfo(CudaGpuId(0));
+ return GetLocalGPUInfo(PlatformGpuId(0));
}
}
return unknown;
diff --git a/tensorflow/core/grappler/clusters/utils.h b/tensorflow/core/grappler/clusters/utils.h
index ca15c48..f0a342b 100644
--- a/tensorflow/core/grappler/clusters/utils.h
+++ b/tensorflow/core/grappler/clusters/utils.h
@@ -28,7 +28,7 @@
// Returns the DeviceProperties for the specified GPU attached to the server on
// which grappler is running.
-DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id);
+DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id);
// Returns the DeviceProperties of the specified device
DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device);
diff --git a/tensorflow/core/grappler/clusters/utils_test.cc b/tensorflow/core/grappler/clusters/utils_test.cc
index 74218ad..3863d62 100644
--- a/tensorflow/core/grappler/clusters/utils_test.cc
+++ b/tensorflow/core/grappler/clusters/utils_test.cc
@@ -31,22 +31,22 @@
LOG(INFO) << "CUDA is enabled.";
DeviceProperties properties;
- // Invalid CUDA GPU ID.
- properties = GetLocalGPUInfo(CudaGpuId(100));
+ // Invalid platform GPU ID.
+ properties = GetLocalGPUInfo(PlatformGpuId(100));
EXPECT_EQ("UNKNOWN", properties.type());
- // Succeed when a valid CUDA GPU id was inserted.
- properties = GetLocalGPUInfo(CudaGpuId(0));
+ // Succeed when a valid platform GPU id was inserted.
+ properties = GetLocalGPUInfo(PlatformGpuId(0));
EXPECT_EQ("GPU", properties.type());
EXPECT_EQ("NVIDIA", properties.vendor());
#else
LOG(INFO) << "CUDA is not enabled.";
DeviceProperties properties;
- properties = GetLocalGPUInfo(CudaGpuId(0));
+ properties = GetLocalGPUInfo(PlatformGpuId(0));
EXPECT_EQ("GPU", properties.type());
- properties = GetLocalGPUInfo(CudaGpuId(100));
+ properties = GetLocalGPUInfo(PlatformGpuId(100));
EXPECT_EQ("GPU", properties.type());
#endif
}
@@ -74,20 +74,20 @@
EXPECT_EQ("NVIDIA", properties.vendor());
#endif
- // TF to CUDA GPU id mapping entry doesn't exist.
+ // TF to platform GPU id mapping entry doesn't exist.
device.has_id = true;
device.id = 0;
properties = GetDeviceInfo(device);
EXPECT_EQ("UNKNOWN", properties.type());
#if GOOGLE_CUDA
- // Invalid CUDA GPU id.
- GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(0), CudaGpuId(100));
+ // Invalid platform GPU id.
+ GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(0), PlatformGpuId(100));
properties = GetDeviceInfo(device);
EXPECT_EQ("UNKNOWN", properties.type());
- // Valid CUDA GPU id.
- GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(1), CudaGpuId(0));
+ // Valid platform GPU id.
+ GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(1), PlatformGpuId(0));
device.id = 1;
properties = GetDeviceInfo(device);
EXPECT_EQ("GPU", properties.type());
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 83434ea..5415324b 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -209,13 +209,13 @@
if (DeviceNameUtils::ParseFullName(device_str, &parsed)) {
if (parsed.type == "GPU") {
TfGpuId tf_gpu_id(parsed.id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (!s.ok()) {
// We are probably running simulation without linking cuda libraries.
- cuda_gpu_id = CudaGpuId(parsed.id);
+ platform_gpu_id = PlatformGpuId(parsed.id);
}
- return GetLocalGPUInfo(cuda_gpu_id);
+ return GetLocalGPUInfo(platform_gpu_id);
} else if (parsed.type == "CPU") {
return GetLocalCPUInfo();
}
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 0292052..261dee4 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -8,10 +8,6 @@
# Platform specific build config
load(
- "//tensorflow/core:platform/default/build_config.bzl",
- "tf_protos_grappler",
-)
-load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"if_static",
)
@@ -97,7 +93,6 @@
deps = [
":evaluation_utils",
":graph_optimizer",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -107,6 +102,7 @@
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
],
)
@@ -261,7 +257,6 @@
":constant_folding",
":graph_optimizer",
":graph_optimizer_stage",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -270,6 +265,7 @@
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
"//tensorflow/core/grappler/utils:topological_sort",
],
)
@@ -648,7 +644,6 @@
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -658,6 +653,7 @@
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:frame",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
],
)
@@ -715,31 +711,6 @@
)
cc_library(
- name = "symbolic_shapes",
- srcs = ["symbolic_shapes.cc"],
- hdrs = ["symbolic_shapes.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- ] + tf_protos_grappler(),
-)
-
-tf_cc_test(
- name = "symbolic_shapes_test",
- srcs = ["symbolic_shapes_test.cc"],
- deps = [
- ":symbolic_shapes",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
-cc_library(
name = "debug_stripper",
srcs = ["debug_stripper.cc"],
hdrs = [
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 992e85d..76a9dca 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -35,8 +35,8 @@
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -2367,26 +2367,24 @@
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
- for (int i = 0; i < p.shape().dim_size(); ++i) {
- if (p.shape().dim(i).size() < 0) {
+ const auto& pow_props =
+ ctx().graph_properties->GetInputProperties(node->name())[1];
+ for (int i = 0; i < pow_props.shape().dim_size(); ++i) {
+ if (pow_props.shape().dim(i).size() < 0) {
// skip if p is is not fully defined.
return Status::OK();
}
}
- if (TensorShape::IsValid(p.shape()) && p.has_value()) {
- Tensor pow(p.dtype(), p.shape());
- if (!pow.FromProto(p.value())) {
+ if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
+ Tensor pow(pow_props.dtype(), pow_props.shape());
+ if (!pow.FromProto(pow_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
- p.value().DebugString());
+ pow_props.value().DebugString());
}
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
- if (!GetElementUnexhaustive(pow, i,
- {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &curr)) {
+ if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
// input data type is not supported by Pow. Skip.
return Status::OK();
}
@@ -2399,12 +2397,19 @@
NodeDef *x, *y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+ const auto& value_props =
+ ctx().graph_properties->GetInputProperties(node->name())[0];
+ const TensorShapeProto& output_shape =
+ ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
if (curr == complex128(2, 0)) {
node->set_op("Square");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
- } else if (curr == complex128(1, 0)) {
+ } else if (curr == complex128(1, 0) &&
+ ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+ // Pow could be used to broadcast, so make sure the shapes of the two
+ // arguments are identical before replacing Pow with Identity.
node->set_op("Identity");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
@@ -2414,20 +2419,20 @@
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
- } else if (curr == complex128(0, 0)) {
- const auto& b =
- ctx().graph_properties->GetInputProperties(node->name())[0];
- for (int i = 0; i < b.shape().dim_size(); ++i) {
- if (b.shape().dim(i).size() < 0) {
+ } else if (curr == complex128(0, 0) &&
+ ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+ for (int i = 0; i < value_props.shape().dim_size(); ++i) {
+ if (value_props.shape().dim(i).size() < 0) {
// skip if b is is not fully defined.
return Status::OK();
}
}
- if (TensorShape::IsValid(b.shape()) && b.has_value()) {
- Tensor base(b.dtype(), b.shape());
- if (!base.FromProto(b.value())) {
+ if (TensorShape::IsValid(value_props.shape()) &&
+ value_props.has_value()) {
+ Tensor base(value_props.dtype(), value_props.shape());
+ if (!base.FromProto(value_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
- b.value().DebugString());
+ value_props.value().DebugString());
}
node->set_op("Const");
Tensor c(base.dtype(), base.shape());
@@ -2585,12 +2590,10 @@
~ConvertExpm1Stage() override = default;
bool IsSupported(const NodeDef* node) const override {
- if (!IsSub(*node))
- return false;
+ if (!IsSub(*node)) return false;
NodeDef* input;
- if (!GetInputNode(node->input(0), &input).ok())
- return false;
+ if (!GetInputNode(node->input(0), &input).ok()) return false;
return IsExp(*input);
}
@@ -2610,10 +2613,8 @@
return Status::OK();
}
- const auto& t =
- ctx().graph_properties->GetInputProperties(exp->name())[0];
- const auto& c =
- ctx().graph_properties->GetInputProperties(node->name())[1];
+ const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
+ const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
for (int k = 0; k < c.shape().dim_size(); ++k) {
// Skip if c shape is not fully determined.
if (c.shape().dim(k).size() < 0) {
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 88839d9..77f3c64 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2474,6 +2474,9 @@
auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
+ auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
+ auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
+ auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
@@ -2481,21 +2484,24 @@
Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
Output out = ops::Pow(s.WithOpName("out"), x, y);
+ Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
+ Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
GrapplerItem item;
- item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
+ item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
+ "out_1", "out", "out_bcast1", "out_bcast2"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
- EXPECT_EQ(7, tensors_expected.size());
+ EXPECT_EQ(9, tensors_expected.size());
GraphDef got;
ArithmeticOptimizer optimizer;
EnableOnlyConvertPow(&optimizer);
OptimizeAndPrune(&optimizer, &item, &got);
auto tensors = EvaluateNodes(got, item.fetch);
- EXPECT_EQ(7, tensors.size());
+ EXPECT_EQ(9, tensors.size());
- for (int i = 0; i < 7; ++i) {
+ for (int i = 0; i < tensors.size(); ++i) {
EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
}
@@ -2509,6 +2515,9 @@
AddNode("y_.5", "Const", {}, {}, &want);
AddNode("y_1", "Const", {}, {}, &want);
AddNode("y", "Const", {}, {}, &want);
+ AddNode("z", "Const", {}, {}, &want);
+ AddNode("ones", "Const", {}, {}, &want);
+ AddNode("zeros", "Const", {}, {}, &want);
AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
@@ -2517,6 +2526,8 @@
AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
AddNode("out", "Pow", {"x", "y"}, {}, &want);
+ AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
+ AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);
CompareGraphs(want, got);
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 99737a7..cfbd298 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -32,8 +32,8 @@
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -437,25 +437,6 @@
}
namespace {
-bool ShapesEqual(const TensorShapeProto& shape1,
- const TensorShapeProto& shape2) {
- if (shape1.unknown_rank() || shape2.unknown_rank()) {
- return false;
- }
- if (shape1.dim_size() != shape2.dim_size()) {
- return false;
- }
- for (int i = 0; i < shape1.dim_size(); ++i) {
- if (shape1.dim(i).size() != shape2.dim(i).size()) {
- return false;
- }
- if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) {
- return false;
- }
- }
- return true;
-}
-
bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
BCast::Vec* shape, int64* min_id) {
if (shape_node.op() == "Shape") {
@@ -2348,7 +2329,8 @@
properties.GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
- const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
+ const bool y_matches_output_shape =
+ ShapesSymbolicallyEqual(output_shape, y_shape);
if (y_matches_output_shape &&
((is_mul && x_is_one) || (is_add && x_is_zero))) {
// 1 * y = y or 0 + y = y.
@@ -2378,7 +2360,8 @@
properties.GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
- const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
+ const bool x_matches_output_shape =
+ ShapesSymbolicallyEqual(output_shape, x_shape);
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero))) {
// x * 1 = x or x / 1 = x or x +/- 0 = x
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index e84df10..79d5fe8 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -49,6 +49,7 @@
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ ":function_utils",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -67,6 +68,7 @@
srcs = ["fusion_utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":function_utils",
":fusion_utils",
":graph_utils",
"//tensorflow/core:framework",
@@ -78,6 +80,40 @@
)
cc_library(
+ name = "function_utils",
+ srcs = ["function_utils.cc"],
+ hdrs = [
+ "function_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "function_utils_test",
+ srcs = ["function_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ],
+)
+
+cc_library(
name = "graph_utils",
srcs = ["graph_utils.cc"],
hdrs = [
@@ -137,7 +173,9 @@
],
visibility = ["//visibility:public"],
deps = [
+ ":function_utils",
":graph_utils",
+ ":vectorization_utils",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:grappler_item",
@@ -409,3 +447,42 @@
"//tensorflow/core/grappler:grappler_item",
],
)
+
+cc_library(
+ name = "vectorization_utils",
+ srcs = ["vectorization_utils.cc"],
+ hdrs = [
+ "vectorization_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ "@com_google_absl//absl/strings",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/utils:functions",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "vectorization_utils_test",
+ srcs = ["vectorization_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":vectorization_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ] + tf_protos_all(),
+)
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc
new file mode 100644
index 0000000..e95ea1a
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc
@@ -0,0 +1,196 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+namespace {
+
+// Returns, in ascending order, the positions of every element of `collection`
+// for which `predicate` returns true. The counter is an `int` to match the
+// element type of the returned vector (no implicit unsigned -> int narrowing).
+template <typename Predicate, typename Collection>
+std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
+                                                const Collection& collection) {
+  std::vector<int> indices;
+  int idx = 0;
+  for (auto&& element : collection) {
+    if (predicate(element)) {
+      indices.push_back(idx);
+    }
+    ++idx;
+  }
+  return indices;
+}
+
+} // namespace
+
+// Builds a descriptor from its components; `full_str` receives the
+// "node_name:output:position" spelling of the same tensor.
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
+                                             const string& output, int position)
+    : full_str(strings::StrCat(node_name, ":", output, ":", position)),
+      node_name(node_name),
+      node_output(output),
+      position(position) {}
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) {
+  // Parses a "node_name:node_output:position" string into its components.
+  // Components that are absent keep their defaults (empty strings, and
+  // position == -1), e.g. for a bare argument name such as "Arg0".
+  full_str = input;
+  StringPiece capture;
+  StringPiece remaining;
+
+  // Parse "node_name".
+  if (strings::Scanner(input)
+          .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
+          .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
+          .GetResult(&remaining, &capture)) {
+    node_name = string(capture.data(), capture.size());
+  }
+
+  // Parse "node_output" if it exists.
+  if (strings::Scanner(remaining)
+          .OneLiteral(":")
+          .RestartCapture()
+          .One(strings::Scanner::LETTER)
+          .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
+          .GetResult(&remaining, &capture)) {
+    node_output = string(capture.data(), capture.size());
+  }
+
+  // Parse "position" if it exists. The scanner guarantees the capture is all
+  // digits, so conversion can only fail on int32 overflow. A malformed tensor
+  // string must not abort the process, so in that case `position` keeps its
+  // default of -1 instead of CHECK-failing.
+  if (strings::Scanner(remaining)
+          .OneLiteral(":")
+          .RestartCapture()
+          .Many(strings::Scanner::DIGIT)
+          .GetResult(nullptr, &capture)) {
+    int32 parsed_position = -1;
+    if (strings::safe_strto32(capture, &parsed_position)) {
+      position = parsed_position;
+    } else {
+      LOG(WARNING) << "Ignoring out-of-range position in function tensor "
+                   << "string: " << input;
+    }
+  }
+}
+
+// TODO(rachelim): Create a utility class similar to MutableGraphView for
+// FunctionDefs, and use that to manipulate functions. Keeping mappings of
+// nodes->inputs/outputs would be more performant, since we would not have to
+// search over all nodes each time.
+// Note that we're not using GrapplerFunctionItem because it doesn't cover
+// some of our desired uses (e.g. changing the outputs of a function), and the
+// FunctionDef -> GraphDef conversion isn't really necessary in this case.
+//
+// Rewrites every node input and every return value of `func` that is exactly
+// `from` so that it reads `to` instead.
+void ReplaceReferences(const string& from, const string& to,
+                       FunctionDef* func) {
+  for (NodeDef& node : *func->mutable_node_def()) {
+    for (string& input : *node.mutable_input()) {
+      if (input == from) input = to;
+    }
+  }
+  for (auto& ret_entry : *func->mutable_ret()) {
+    if (ret_entry.second == from) ret_entry.second = to;
+  }
+}
+
+// Appends an output arg of type `dt` to `function`. The arg is named `prefix`,
+// or "prefix/_<n>" when that is already taken, with n counting up from the
+// current number of outputs. The ret map entry for the new name is set to
+// `output_tensor_name`.
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+                                     StringPiece output_tensor_name,
+                                     FunctionDef* function, DataType dt) {
+  string unique_name = string(prefix);
+  for (int suffix = function->signature().output_arg_size();
+       ContainsFunctionOutputWithName(unique_name, *function); ++suffix) {
+    unique_name = strings::StrCat(prefix, "/_", suffix);
+  }
+  auto* new_arg = function->mutable_signature()->add_output_arg();
+  new_arg->set_name(unique_name);
+  new_arg->set_type(dt);
+  (*function->mutable_ret())[unique_name] = string(output_tensor_name);
+}
+
+// Appends a node running `op` to `fd` and returns it. An empty `name` gives
+// the node a unique name derived from `op`. `inputs` are copied in order. Each
+// (key, value) pair of `attributes` is stored in the node's attr map, and a
+// later duplicate key overwrites an earlier one.
+NodeDef* AddNode(StringPiece name, StringPiece op,
+                 const std::vector<string>& inputs,
+                 const std::vector<std::pair<string, AttrValue>>& attributes,
+                 FunctionDef* fd) {
+  NodeDef* node = fd->add_node_def();
+  if (!name.empty()) {
+    node->set_name(string(name));
+  } else {
+    SetUniqueFunctionNodeName(op, fd, node);
+  }
+  node->set_op(string(op));
+  for (const string& input : inputs) {
+    node->add_input(input);
+  }
+  // Bind by const reference: iterating by value would copy each attribute
+  // name and AttrValue proto once more before copying it into the map.
+  for (const auto& attr : attributes) {
+    (*node->mutable_attr())[attr.first] = attr.second;
+  }
+  return node;
+}
+
+// Returns true iff `function` has a node named `name`.
+bool ContainsFunctionNodeWithName(StringPiece name,
+                                  const FunctionDef& function) {
+  const int index = FindFunctionNodeWithName(name, function);
+  return index >= 0;
+}
+
+// Returns true iff some node of `function` runs the op `op`.
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+  const int index = FindFunctionNodeWithOp(op, function);
+  return index >= 0;
+}
+
+// Returns true iff `function`'s signature has an output arg named `name`.
+bool ContainsFunctionOutputWithName(StringPiece name,
+                                    const FunctionDef& function) {
+  const int index = FindFunctionOutputWithName(name, function);
+  return index >= 0;
+}
+
+// Returns the position of the first input arg of `function` called `name`,
+// or -1 if there is none.
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
+  const auto has_name = [&name](const OpDef_ArgDef& arg) {
+    return arg.name() == name;
+  };
+  const std::vector<int> matches = GetElementIndicesWithPredicate(
+      has_name, function.signature().input_arg());
+  if (matches.empty()) return -1;
+  return matches[0];
+}
+
+// Returns the position of the first output arg of `function` called `name`,
+// or -1 if there is none.
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
+  const auto has_name = [&name](const OpDef_ArgDef& arg) {
+    return arg.name() == name;
+  };
+  const std::vector<int> matches = GetElementIndicesWithPredicate(
+      has_name, function.signature().output_arg());
+  if (matches.empty()) return -1;
+  return matches[0];
+}
+
+// Returns the index in `function.node_def()` of the first node named `name`,
+// or -1 if there is none.
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
+  const auto is_named = [&name](const NodeDef& node) {
+    return node.name() == name;
+  };
+  const std::vector<int> matches =
+      GetElementIndicesWithPredicate(is_named, function.node_def());
+  if (matches.empty()) return -1;
+  return matches[0];
+}
+
+// Returns the index in `function.node_def()` of the first node running `op`,
+// or -1 if there is none.
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+  const auto runs_op = [&op](const NodeDef& node) { return node.op() == op; };
+  const std::vector<int> matches =
+      GetElementIndicesWithPredicate(runs_op, function.node_def());
+  if (matches.empty()) return -1;
+  return matches[0];
+}
+
+// Names `node` after `prefix`, appending "/_<n>" (n counting up from the
+// current node count) until no existing node of `function` has that name.
+// Only the name is set; `node` is not added to `function`.
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+                               NodeDef* node) {
+  string candidate = string(prefix);
+  for (int suffix = function->node_def_size();
+       ContainsFunctionNodeWithName(candidate, *function); ++suffix) {
+    candidate = strings::StrCat(prefix, "/_", suffix);
+  }
+  node->set_name(std::move(candidate));
+}
+
+} // end namespace function_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.h b/tensorflow/core/grappler/optimizers/data/function_utils.h
new file mode 100644
index 0000000..d4ce824
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.h
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+// This namespace contains utility functions for querying and modifying
+// FunctionDefs.
+
+// Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings
+// have the format node_name:node_output:position (if they derive from nodes),
+// or input_name (if they derive from an argument).
+struct FunctionDefTensorDesc {
+ FunctionDefTensorDesc() = default;
+
+ // Builds a descriptor from its components; `full_str` is set to
+ // "node_name:output:position".
+ FunctionDefTensorDesc(const string& node_name, const string& output,
+ int position);
+
+ // Parses node_name:node_output:position string into its components.
+ // Components missing from `input` keep their default values below.
+ explicit FunctionDefTensorDesc(const string& input);
+
+ // TODO(rachelim): Add provisions to deal with special formats, like how
+ // GrapplerFunctionItem expands node output range if position is not defined.
+
+ // The complete tensor string, e.g. "Cast:y:0" or "Arg0".
+ string full_str;
+ // Node name, or the whole string for an argument reference such as "Arg0".
+ string node_name;
+ // The node_output component; empty when the string has none.
+ string node_output;
+ // The position component; -1 when the string has none.
+ int position = -1;
+};
+
+// Replaces every reference to the `from` tensor in func's node inputs and
+// return values with the `to` tensor. This is similar to
+// `MutableGraphView::ReplaceInputs`.
+void ReplaceReferences(const string& from, const string& to, FunctionDef* func);
+
+// Adds a function output to the function def, ensuring that the output key
+// is unique, and maps to output_tensor_name in the ret dict.
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+ StringPiece output_tensor_name,
+ FunctionDef* function, DataType dt);
+
+// Adds a node to a FunctionDef.
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd);
+
+// Checks whether the function contains a node with the given name.
+bool ContainsFunctionNodeWithName(StringPiece name,
+ const FunctionDef& function);
+
+// Checks whether the function contains a node with the given op.
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Checks whether the function contains an output with the given name.
+bool ContainsFunctionOutputWithName(StringPiece name,
+ const FunctionDef& function);
+
+// Returns the index of the function input with the given name or -1 if the
+// function has no such input.
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function output with the given name or -1 if the
+// function has no such output.
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given name or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given op or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Sets the function node name using the `prefix` as a prefix while guaranteeing
+// the name is unique across the function's nodes.
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+ NodeDef* node);
+
+} // end namespace function_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
new file mode 100644
index 0000000..3739e20
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+namespace {
+
+TEST(FunctionDefTensorDesc, Parsing) {
+  FunctionDefTensorDesc f("Cast:y:0");
+  EXPECT_EQ(f.full_str, "Cast:y:0");
+  EXPECT_EQ(f.node_name, "Cast");
+  EXPECT_EQ(f.node_output, "y");
+  EXPECT_EQ(f.position, 0);
+
+  FunctionDefTensorDesc f2("Arg0");
+  EXPECT_EQ(f2.full_str, "Arg0");
+  EXPECT_EQ(f2.node_name, "Arg0");
+  EXPECT_EQ(f2.node_output, "");
+  EXPECT_EQ(f2.position, -1);
+
+  // An output without a position leaves `position` at its default.
+  FunctionDefTensorDesc f3("MapDefun:output");
+  EXPECT_EQ(f3.node_name, "MapDefun");
+  EXPECT_EQ(f3.node_output, "output");
+  EXPECT_EQ(f3.position, -1);
+
+  // Multi-digit positions are parsed in full.
+  FunctionDefTensorDesc f4("MapDefun:output:12");
+  EXPECT_EQ(f4.node_output, "output");
+  EXPECT_EQ(f4.position, 12);
+}
+
+TEST(FunctionDefTensorDesc, ConstructFromComponents) {
+  FunctionDefTensorDesc f("Cast", "y", 0);
+  EXPECT_EQ(f.full_str, "Cast:y:0");
+  EXPECT_EQ(f.node_name, "Cast");
+  EXPECT_EQ(f.node_output, "y");
+  EXPECT_EQ(f.position, 0);
+}
+
+TEST(ReplaceReferencesTest, ReplaceReferencesTest) {
+  FunctionDef outer = FunctionDefHelper::Create(
+      "outer", {"arg0: int32"}, {"out: int32", "out2: int64"}, {}, {},
+      {{"out", "MapDefun:output:0"}, {"out2", "Cast:y:0"}});
+  NodeDef* derive_node =
+      AddNode("X", "Some_Op", {"MapDefun:output:0"}, {}, &outer);
+  NodeDef* unrelated_node = AddNode("Y", "Some_Op", {"Cast:y:0"}, {}, &outer);
+  // Check that both the input to "X" and retval of "outer" are replaced.
+  ReplaceReferences("MapDefun:output:0", "arg0", &outer);
+  EXPECT_EQ(outer.ret().at("out"), "arg0");
+  EXPECT_EQ(derive_node->input(0), "arg0");
+  // References to other tensors must be left untouched.
+  EXPECT_EQ(outer.ret().at("out2"), "Cast:y:0");
+  EXPECT_EQ(unrelated_node->input(0), "Cast:y:0");
+}
+
+TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) {
+  FunctionDef function = test::function::XTimesTwo();
+  // "y" is already an output of XTimesTwo, so the new one gets a suffix.
+  AddFunctionOutputWithUniqueName("y", "two", &function, DT_INT64);
+  const string expected_name = "y/_1";
+  EXPECT_TRUE(ContainsFunctionOutputWithName(expected_name, function));
+  EXPECT_EQ(function.ret().at(expected_name), "two");
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) {
+  const FunctionDef function = test::function::XTimesTwo();
+  EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
+  EXPECT_FALSE(ContainsFunctionNodeWithName(
+      "weird_name_that_should_not_be_there", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithOp) {
+  const FunctionDef function = test::function::XTimesTwo();
+  EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
+  EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
+                                          function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionOutputWithName) {
+  const FunctionDef function = test::function::XTimesTwo();
+  // A node tensor string is not an output arg name.
+  EXPECT_FALSE(ContainsFunctionOutputWithName("Add:z:0", function));
+  EXPECT_TRUE(ContainsFunctionOutputWithName("y", function));
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithName) {
+  const FunctionDef function = test::function::XTimesTwo();
+  const int missing =
+      FindFunctionNodeWithName("weird_name_that_should_not_be_there", function);
+  EXPECT_EQ(missing, -1);
+  EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithOp) {
+  const FunctionDef function = test::function::XTimesTwo();
+  const int missing =
+      FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function);
+  EXPECT_EQ(missing, -1);
+  EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionInputWithName) {
+  const FunctionDef function = test::function::XTimesTwo();
+  EXPECT_EQ(FindFunctionInputWithName("not_a_name", function), -1);
+  EXPECT_EQ(FindFunctionInputWithName("x", function), 0);
+}
+
+TEST(FunctionUtilsTest, FindFunctionOutputWithName) {
+  const FunctionDef function = test::function::XTimesTwo();
+  EXPECT_EQ(FindFunctionOutputWithName("Add:z:0", function), -1);
+  EXPECT_EQ(FindFunctionOutputWithName("y", function), 0);
+}
+
+TEST(FunctionUtilsTest, SetUniqueFunctionNodeName) {
+  FunctionDef function = test::function::XTimesTwo();
+  NodeDef first;
+  SetUniqueFunctionNodeName("abc", &function, &first);
+  for (const NodeDef& existing : function.node_def()) {
+    EXPECT_NE(first.name(), existing.name());
+  }
+  NodeDef* added = function.add_node_def();
+  *added = first;
+
+  // Once "abc" is taken, a second request must yield a different name.
+  NodeDef second;
+  SetUniqueFunctionNodeName("abc", &function, &second);
+  EXPECT_NE(second.name(), added->name());
+}
+
+TEST(FunctionUtilsTest, AddNodeToFunctionDef) {
+  FunctionDef func;
+  const char* op_name = "xxx";
+  AddNode(op_name, op_name, {}, {}, &func);
+
+  const int node1_index = FindFunctionNodeWithName("xxx", func);
+  ASSERT_NE(node1_index, -1);
+  const NodeDef& node1 = func.node_def(node1_index);
+  EXPECT_EQ(node1.op(), op_name);
+  EXPECT_EQ(node1.input_size(), 0);
+  EXPECT_EQ(node1.attr_size(), 0);
+
+  // An empty name makes AddNode derive a unique name from the op.
+  const std::vector<string> inputs({"input1", "input2"});
+  AddNode("", op_name, inputs, {}, &func);
+  const int node2_index = FindFunctionNodeWithName("xxx/_2", func);
+  ASSERT_NE(node2_index, -1);
+  const NodeDef& node2 = func.node_def(node2_index);
+  EXPECT_EQ(node2.op(), op_name);
+  EXPECT_EQ(node2.attr_size(), 0);
+  // Protobuf sizes are int and std::vector sizes are size_t: compare as int,
+  // and stop before indexing if the counts differ.
+  ASSERT_EQ(node2.input_size(), static_cast<int>(inputs.size()));
+  for (int i = 0; i < node2.input_size(); ++i) {
+    EXPECT_EQ(node2.input(i), inputs[i]);
+  }
+
+  AttrValue a1, a2;
+  a1.set_type(DT_INT32);
+  a2.set_type(DT_INT64);
+  const std::vector<std::pair<string, AttrValue>> attrs(
+      {{"attr1", a1}, {"attr2", a2}});
+  AddNode("", op_name, {}, attrs, &func);
+  const int node3_index = FindFunctionNodeWithName("xxx/_3", func);
+  ASSERT_NE(node3_index, -1);
+  const NodeDef& node3 = func.node_def(node3_index);
+  EXPECT_EQ(node3.op(), op_name);
+  EXPECT_EQ(node3.input_size(), 0);
+  ASSERT_EQ(node3.attr_size(), static_cast<int>(attrs.size()));
+  for (const auto& attr : attrs) {
+    EXPECT_EQ(attr.second.type(), node3.attr().at(attr.first).type());
+  }
+}
+
+} // namespace
+} // namespace function_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
index 01a78c0..b3bfee1 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
@@ -22,6 +22,7 @@
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -407,7 +408,7 @@
auto* if_node = fused_function->add_node_def();
// This is guaranteed to succeed.
TF_CHECK_OK(if_builder.Finalize(if_node));
- graph_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
+ function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
}
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
index d5c6466..e667aff 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
@@ -19,6 +19,7 @@
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -110,9 +111,9 @@
CheckUniqueNames(*fused_function);
ASSERT_TRUE(
- graph_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
+ function_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
const auto &equal_node = fused_function->node_def(
- graph_utils::FindFunctionNodeWithOp("Equal", *fused_function));
+ function_utils::FindFunctionNodeWithOp("Equal", *fused_function));
EXPECT_EQ(xtimes_two->signature().output_arg(0).name(),
fused_function->signature().output_arg(0).name());
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index d4ab444..b3f60e3 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -108,26 +108,6 @@
return graph->AddNode(std::move(node));
}
-NodeDef* AddNode(StringPiece name, StringPiece op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- FunctionDef* fd) {
- NodeDef* node = fd->add_node_def();
- if (!name.empty()) {
- node->set_name(string(name));
- } else {
- SetUniqueFunctionNodeName(op, fd, node);
- }
- node->set_op(string(op));
- for (const string& input : inputs) {
- node->add_input(input);
- }
- for (auto attr : attributes) {
- (*node->mutable_attr())[attr.first] = attr.second;
- }
- return node;
-}
-
template <>
NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
@@ -196,6 +176,11 @@
return true;
}
+// Returns true iff `library` holds a function whose signature is named `name`.
+bool ContainsGraphFunctionWithName(StringPiece name,
+                                   const FunctionDefLibrary& library) {
+  const int index = FindGraphFunctionWithName(name, library);
+  return index >= 0;
+}
+
bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
return FindGraphNodeWithName(name, graph) != -1;
}
@@ -204,18 +189,14 @@
return FindGraphNodeWithOp(op, graph) != -1;
}
-bool ContainsGraphFunctionWithName(StringPiece name,
- const FunctionDefLibrary& library) {
- return FindGraphFunctionWithName(name, library) != -1;
-}
-
-bool ContainsFunctionNodeWithName(StringPiece name,
- const FunctionDef& function) {
- return FindFunctionNodeWithName(name, function) != -1;
-}
-
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- return FindFunctionNodeWithOp(op, function) != -1;
+// Returns the index of the first function in `library` whose signature name
+// is `name`, or -1 if there is none.
+int FindGraphFunctionWithName(StringPiece name,
+                              const FunctionDefLibrary& library) {
+  const auto is_named = [&name](const FunctionDef& function) {
+    return function.signature().name() == name;
+  };
+  const std::vector<int> matches =
+      GetElementIndicesWithPredicate(is_named, library.function());
+  if (matches.empty()) return -1;
+  return matches[0];
}
int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
@@ -237,31 +218,6 @@
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
}
-int FindGraphFunctionWithName(StringPiece name,
- const FunctionDefLibrary& library) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&name](const FunctionDef& function) {
- return function.signature().name() == name;
- },
- library.function());
- return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&name](const NodeDef& node) { return node.name() == name; },
- function.node_def());
- return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&op](const NodeDef& node) { return node.op() == op; },
- function.node_def());
-
- return indices.empty() ? -1 : indices.front();
-}
-
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
if (node.input_size() == 0) return nullptr;
GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
@@ -284,17 +240,6 @@
node->set_name(std::move(name));
}
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
- NodeDef* node) {
- string name = string(prefix);
- int id = function->node_def_size();
- while (ContainsFunctionNodeWithName(name, *function)) {
- name = strings::StrCat(prefix, "/_", id);
- ++id;
- }
- node->set_name(std::move(name));
-}
-
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function) {
string name = string(prefix);
@@ -305,7 +250,6 @@
}
function->mutable_signature()->set_name(std::move(name));
}
-
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 6f431c2..1652afc 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -37,12 +37,6 @@
const std::vector<std::pair<string, AttrValue>>& attributes,
MutableGraphView* graph);
-// Adds a node to a FunctionDef.
-NodeDef* AddNode(StringPiece name, StringPiece op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- FunctionDef* fd);
-
// Adds a Const node with the given value to the graph.
template <typename T>
NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
@@ -76,13 +70,6 @@
bool ContainsGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library);
-// Checks whether the function contains a node with the given name.
-bool ContainsFunctionNodeWithName(StringPiece name,
- const FunctionDef& function);
-
-// Checks whether the function contains a node with the given op.
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
// Checks whether the graph contains a node with the given op.
bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph);
@@ -95,14 +82,6 @@
int FindGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library);
-// Returns the index of the function node with the given name or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
-
-// Returns the index of the function node with the given op or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
// Returns the index of the first node with the given op or -1 if no such node
// exists.
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
@@ -119,11 +98,6 @@
// is unique across the graph.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
-// Sets the function node name using the `prefix` as a prefix while guaranteeing
-// the name is unique across the functions nodes.
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
- NodeDef* node);
-
// Sets the node name using the `prefix` name as a prefix while guaranteeing the
// name is unique across the graph.
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index c19ac7b..6877c20 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -112,20 +112,6 @@
ContainsGraphFunctionWithName(new_function->signature().name(), library));
}
-TEST(GraphUtilsTest, ContainsFunctionNodeWithName) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_FALSE(ContainsFunctionNodeWithName(
- "weird_name_that_should_not_be_there", function));
- EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
-}
-
-TEST(GraphUtilsTest, ContainsFunctionNodeWithOp) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
- function));
- EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
-}
-
TEST(GraphUtilsTest, ContainsNodeWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
@@ -150,22 +136,6 @@
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
}
-TEST(GraphUtilsTest, FindFunctionNodeWithName) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_EQ(
- FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
- -1);
- EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
-}
-
-TEST(GraphUtilsTest, FindFunctionNodeWithOp) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_EQ(
- FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
- -1);
- EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
-}
-
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
FunctionDefLibrary library;
EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1);
@@ -225,21 +195,6 @@
EXPECT_NE(node2->name(), node3->name());
}
-TEST(GraphUtilsTest, SetUniqueFunctionNodeName) {
- FunctionDef function = test::function::XTimesTwo();
- NodeDef node;
- SetUniqueFunctionNodeName("abc", &function, &node);
- for (const NodeDef& function_node : function.node_def()) {
- EXPECT_NE(node.name(), function_node.name());
- }
- auto* new_node = function.add_node_def();
- *new_node = node;
-
- NodeDef other;
- SetUniqueFunctionNodeName("abc", &function, &other);
- EXPECT_NE(other.name(), new_node->name());
-}
-
TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
FunctionDefLibrary library;
FunctionDef* new_function = library.add_function();
@@ -251,43 +206,6 @@
other_function->signature().name());
}
-TEST(GraphUtilsTest, AddNodeToFunctionDef) {
- FunctionDef func;
- const char* op_name = "xxx";
- AddNode(op_name, op_name, {}, {}, &func);
-
- const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
- EXPECT_EQ(node1.op(), op_name);
- EXPECT_EQ(node1.input_size(), 0);
- EXPECT_EQ(node1.attr_size(), 0);
-
- const std::vector<string> inputs({"input1", "input2"});
- AddNode("", op_name, inputs, {}, &func);
- const NodeDef& node2 =
- func.node_def(FindFunctionNodeWithName("xxx/_2", func));
- EXPECT_EQ(node2.op(), op_name);
- EXPECT_EQ(node2.attr_size(), 0);
- EXPECT_EQ(node2.input_size(), inputs.size());
- for (size_t i = 0; i < inputs.size(); ++i) {
- EXPECT_EQ(node2.input(i), inputs[i]);
- }
-
- AttrValue a1, a2;
- a1.set_type(DT_INT32);
- a2.set_type(DT_INT64);
- const std::vector<std::pair<string, AttrValue>> attrs(
- {{"attr1", a1}, {"attr2", a2}});
- AddNode("", op_name, {}, attrs, &func);
- const NodeDef& node3 =
- func.node_def(FindFunctionNodeWithName("xxx/_3", func));
- EXPECT_EQ(node3.op(), op_name);
- EXPECT_EQ(node3.input_size(), 0);
- EXPECT_EQ(node3.attr_size(), attrs.size());
- for (size_t i = 0; i < attrs.size(); ++i) {
- EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
- }
-}
-
TEST(GraphUtilsTest, GetInputNode) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index a019b77..7a2f191 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -24,6 +25,7 @@
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -37,11 +39,11 @@
(*to->mutable_attr())[attr_name] = from.attr().at(attr_name);
}
-FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
+// Returns a FunctionDef containing a MapDefun op that wraps the original
+// function.
+FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
const FunctionDef& orig_func,
FunctionDefLibrary* library) {
- // If we decide to use a different method of vectorization, we can just
- // swap out this part.
FunctionDef* vectorized_func = library->add_function();
// Function inputs and outputs are the same as original, just
// with different shapes.
@@ -52,8 +54,8 @@
// Add MapDefun node
NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add();
map_defun_node->set_op("MapDefun");
- graph_utils::SetUniqueFunctionNodeName(map_defun_node->op(), vectorized_func,
- map_defun_node);
+ function_utils::SetUniqueFunctionNodeName(map_defun_node->op(),
+ vectorized_func, map_defun_node);
// Set attrs and inputs
for (const string& k : {"f", "output_types", "output_shapes"}) {
@@ -81,6 +83,30 @@
return vectorized_func;
}
+FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
+ const FunctionDef& orig_func,
+ FunctionDefLibrary* library) {
+ // Vectorizes orig_func naively by wrapping in a MapDefun op, then performing
+ // efficient vectorization with VectorizeMapDefun.
+ FunctionDef* vectorized_func =
+ CreateMapDefunWrapper(map_node, orig_func, library);
+ NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0);
+ DCHECK_EQ(map_defun_node->op(), "MapDefun");
+
+ // Create a copy of the original function so that we can mutate it, and
+ // attach that to the map defun node.
+ FunctionDef* map_defun_fn = library->add_function();
+ *map_defun_fn = orig_func;
+ graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library,
+ map_defun_fn);
+ (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name(
+ map_defun_fn->signature().name());
+
+ vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn,
+ map_defun_node);
+ return vectorized_func;
+}
+
bool IsOutputShapesFullyDefined(const NodeDef& node) {
auto* shapes_attr = gtl::FindOrNull(node.attr(), "output_shapes");
if (shapes_attr == nullptr) return false;
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
new file mode 100644
index 0000000..bfca63b
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -0,0 +1,346 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/functions.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+using function_utils::FunctionDefTensorDesc;
+
+namespace {
+
+void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
+ const string& output_retval, const DataType t) {
+ // Set to unknown shape
+ TensorShapeProto tensor_shape_proto;
+ PartialTensorShape().AsProto(&tensor_shape_proto);
+
+ function_utils::AddFunctionOutputWithUniqueName(
+ "vectorized_out", output_retval, map_defun_fn, t);
+
+ *(*map_defun_node->mutable_attr())["output_shapes"]
+ .mutable_list()
+ ->add_shape() = tensor_shape_proto;
+ (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+}
+
+void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node, int output_position) {
+ DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
+
+ int num_later_outputs =
+ map_defun_fn->signature().output_arg_size() - output_position - 1;
+
+ // Remove from map_defun_fn's ret dict and output args
+ map_defun_fn->mutable_ret()->erase(
+ map_defun_fn->signature().output_arg(output_position).name());
+ map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
+ output_position, 1);
+
+ // Renumber outputs that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ function_utils::ReplaceReferences(
+ strings::StrCat(map_defun_node->name(),
+ ":output:", output_position + i + 1),
+ strings::StrCat(map_defun_node->name(),
+ ":output:", output_position + i),
+ outer_scope);
+ }
+ map_defun_node->mutable_attr()
+ ->at("output_shapes")
+ .mutable_list()
+ ->mutable_shape()
+ ->DeleteSubrange(output_position, 1);
+ map_defun_node->mutable_attr()
+ ->at("output_types")
+ .mutable_list()
+ ->mutable_type()
+ ->ExtractSubrange(output_position, 1, nullptr);
+}
+
+Status ConvertCastOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs,
+ const NodeDef& cast_node,
+ std::map<string, string>* conversion_map) {
+ if (inputs.size() != 1) {
+ return errors::Internal("Cast op should only have one input.");
+ }
+
+ // Add new Cast node
+ NodeDef* new_cast_node = outer_scope->add_node_def();
+ *new_cast_node = cast_node;
+ new_cast_node->clear_name();
+ function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("vectorized/", cast_node.name()), outer_scope,
+ new_cast_node);
+ new_cast_node->set_input(0, inputs[0]);
+
+ // Add the output mapping to conversion map
+ (*conversion_map)[strings::StrCat(cast_node.name(), ":y:0")] =
+ strings::StrCat(new_cast_node->name(), ":y:0");
+
+ return Status::OK();
+}
+
+Status ConvertUnpackOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs,
+ const NodeDef& unpack_node,
+ std::map<string, string>* conversion_map) {
+ if (inputs.size() != 1) {
+ return errors::Internal("Unpack op should only have one input.");
+ }
+
+ // Add new Unpack node
+ NodeDef* new_unpack_node = outer_scope->add_node_def();
+ *new_unpack_node = unpack_node;
+ new_unpack_node->clear_name();
+ function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("vectorized/", unpack_node.name()), outer_scope,
+ new_unpack_node);
+
+ // Increment "axis" attr by 1:
+ (*new_unpack_node->mutable_attr())["axis"].set_i(
+ unpack_node.attr().at("axis").i() + 1);
+ new_unpack_node->set_input(0, inputs[0]);
+
+ // Add the output mappings to conversion map
+ int num = new_unpack_node->attr().at("num").i();
+ for (int i = 0; i < num; ++i) {
+ (*conversion_map)[strings::StrCat(unpack_node.name(), ":output:", i)] =
+ strings::StrCat(new_unpack_node->name(), ":output:", i);
+ }
+
+ return Status::OK();
+}
+
+int FindOutputToConvert(const FunctionDef& function,
+ const std::set<string>& unconvertible,
+ FunctionDefTensorDesc* f) {
+ for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
+ const string& ret_key = function.signature().output_arg(i).name();
+ *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+
+ if (unconvertible.find(f->node_name) == unconvertible.end()) {
+ return i;
+ }
+ }
+ return -1;
+}
+
+// Helper class that vectorizes the body of a MapDefun node, adding new
+// operations to the graph that collectively compute the same value as what
+// running the MapDefun function on slices of the input would produce.
+// Each instance of the class encapsulates all the data necessary to vectorize a
+// MapDefun op in place.
+class Vectorization {
+ public:
+ Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node)
+ : outer_scope_(outer_scope),
+ map_defun_fn_(map_defun_fn),
+ map_defun_node_(map_defun_node) {}
+
+ // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
+ // the outer_scope_, until there are no convertible outputs remaining.
+ // This method is idempotent.
+ void Vectorize();
+
+ private:
+ // Vectorizes the map defun function's output at output_position
+ Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
+ // Given a descriptor of the original output tensor, gets a string
+ // corresponding to the converted output tensor.
+ Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
+ string* converted);
+ Status AddConversionMappingFromInput(
+ const FunctionDefTensorDesc& output_desc);
+
+ // Adds mappings from node's outputs tensors to converted output tensors,
+ // creating the necessary new node(s). Generally, the steps to convert an op
+ // are:
+ // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_,
+ // and modify map_defun_node_ attrs accordingly
+ // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+ // These operations collectively compute the same value as what running
+ // the original operation on slices of the input tensors would produce.
+ // For example, a Cast op in MapDefun translates to a Cast op in
+ // outer_scope_, since the vectorized version of Cast is itself.
+ // 3) Set inputs of new node(s) to the corresponding converted inputs (that
+ // are now outputs of map_defun_node_)
+ // 4) For each output of the old node, add the mapping of output strings to
+ // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0")
+ Status AddConversionMappingFromOp(const NodeDef& node,
+ const FunctionDefTensorDesc& output_desc);
+
+ // Maps a tensor name to the name of the corresponding vectorized tensor. For
+ // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
+ std::map<string, string> conversion_map_;
+ // Unconvertible node names
+ std::set<string> unconvertible_;
+
+ FunctionDef* outer_scope_;
+ FunctionDef* map_defun_fn_;
+ NodeDef* map_defun_node_;
+};
+
+Status Vectorization::AddConversionMappingFromOp(
+ const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
+ for (const string& input_name : node.input()) {
+ if (IsControlInput(input_name)) {
+ return errors::InvalidArgument(
+ "Vectorizing outputs with control inputs is currently not "
+ "supported.");
+ }
+ }
+
+ // TODO(rachelim): Have some mechanism for registering converters and some
+ // uniform, simpler way to represent them.
+
+ DataTypeVector types;
+ const OpDef* op_def = nullptr;
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
+ TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
+
+ std::vector<string> promoted_inputs;
+ promoted_inputs.reserve(node.input_size());
+ for (int i = 0; i < node.input_size(); ++i) {
+ promoted_inputs.push_back(strings::StrCat(
+ map_defun_node_->name(),
+ ":output:", map_defun_fn_->signature().output_arg_size() + i));
+ }
+
+ if (node.op() == "Cast") {
+ TF_RETURN_IF_ERROR(
+ ConvertCastOp(outer_scope_, promoted_inputs, node, &conversion_map_));
+ } else if (node.op() == "Unpack") {
+ TF_RETURN_IF_ERROR(
+ ConvertUnpackOp(outer_scope_, promoted_inputs, node, &conversion_map_));
+ } else {
+ return errors::Unimplemented("Op converter for \"", node.op(),
+ "\" not implemented yet");
+ }
+
+ // If we get here, the conversion was successful, so we promote the inputs
+ // of the ops to MapDefun outputs.
+ for (int i = 0; i < types.size(); ++i) {
+ AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+ }
+
+ return Status::OK();
+}
+
+Status Vectorization::AddConversionMappingFromInput(
+ const FunctionDefTensorDesc& output_desc) {
+ int input_index = function_utils::FindFunctionInputWithName(
+ output_desc.node_name, *map_defun_fn_);
+ if (input_index == -1) {
+ return errors::Internal("Cannot convert non-existent input.");
+ }
+
+ conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutputHelper(
+ const FunctionDefTensorDesc& output_desc, string* converted) {
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
+ *converted = *found;
+ return Status::OK();
+ }
+
+ int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
+ *map_defun_fn_);
+ if (index == -1) { // The output comes from an input
+ TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
+ } else {
+ TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
+ map_defun_fn_->node_def(index), output_desc));
+ }
+ *converted = conversion_map_.at(output_desc.full_str);
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutput(int output_position,
+ const FunctionDefTensorDesc& output_desc) {
+ string converted_output_name;
+ TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+
+ // Remove the old output and make everything that referenced it point
+ // to the new string
+ function_utils::ReplaceReferences(
+ strings::StrCat(map_defun_node_->name(), ":output:", output_position),
+ converted_output_name, outer_scope_);
+ RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
+ output_position);
+
+ return Status::OK();
+}
+
+void Vectorization::Vectorize() {
+ while (true) {
+ FunctionDefTensorDesc desc;
+ int output_position =
+ FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
+ if (output_position == -1) break;
+
+ if (!ConvertOutput(output_position, desc).ok()) {
+ unconvertible_.insert(desc.node_name);
+ }
+ }
+
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->signature().output_arg_size() == 0) {
+ outer_scope_->mutable_node_def()->DeleteSubrange(
+ function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
+ *outer_scope_),
+ 1);
+ }
+
+ if (!unconvertible_.empty()) {
+ VLOG(2) << "The following nodes could not be converted: ["
+ << absl::StrJoin(unconvertible_, ", ") << "].";
+ }
+}
+} // namespace
+
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node) {
+ Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+}
+
+} // end namespace vectorization_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
new file mode 100644
index 0000000..bb405fa
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Given a function, `map_defun_fn`, that is mapped across some input vector
+// elements via a MapDefun operation, `VectorizeMapDefun` attempts to
+// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the
+// `outer_scope`; that is, replacing `map_defun_fn` operations with new
+// `outer_scope` operations that produce the same vector output(s) as executing
+// the `map_defun_fn` operations on elements of vector input(s) would. If all
+// `map_defun_fn` operations are successfully lifted, `map_defun_node` is
+// eliminated from `outer_scope` altogether. However, if some operations cannot
+// be lifted, and this vectorization only succeeds partially, `map_defun_node`
+// remains to be used for operations that were not lifted.
+//
+// Example:
+// If the input to the `VectorizeMapDefun` function is a MapDefun
+// whose `map_defun_fn` performs the Cast operation, the vectorization will
+// eliminate the MapDefun. This is because the Cast operation supports
+// any tensor shape and can thus be lifted to the `outer_scope`.
+//
+// Before:
+//
+//
+// outer_scope +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | map_defun_fn +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// outer_scope +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node);
+
+} // end namespace vectorization_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
new file mode 100644
index 0000000..e129fa9
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -0,0 +1,600 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+namespace {
+
+NodeDef* AddCastNode(const string& name, const std::vector<string>& inputs,
+ DataType src, DataType dst, bool truncate,
+ FunctionDef* fn) {
+ NodeDef* node = function_utils::AddNode(name, "Cast", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("SrcT", src, node);
+ graph_transforms::SetNodeAttr("DstT", dst, node);
+ graph_transforms::SetNodeAttr("Truncate", truncate, node);
+ return node;
+}
+
+NodeDef* AddUnstackNode(const string& name, const std::vector<string>& inputs,
+ DataType t, int axis, int num, FunctionDef* fn) {
+ NodeDef* node = function_utils::AddNode(name, "Unpack", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("T", t, node);
+ graph_transforms::SetNodeAttr("axis", axis, node);
+ graph_transforms::SetNodeAttr("num", num, node);
+ return node;
+}
+
+NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
+ const std::vector<DataType>& t_arguments,
+ const std::vector<DataType>& output_types,
+ const std::vector<TensorShape>& output_shapes,
+ const string& function_name, FunctionDef* fn) {
+ NameAttrList func;
+ func.set_name(function_name);
+ NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+ graph_transforms::SetNodeAttr("output_types", output_types, node);
+ graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
+ graph_transforms::SetNodeAttr("f", func, node);
+ return node;
+}
+
+// TODO(rachelim): Use FunctionDefHelper::Create instead
+FunctionDef CreateFunction(
+ StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
+ const std::vector<std::pair<string, DataType>>& outputs,
+ const std::map<string, string>& rets) {
+ FunctionDef func;
+ auto* signature = func.mutable_signature();
+ signature->set_name(string(name));
+ for (const auto& x : inputs) {
+ auto* arg_def = signature->add_input_arg();
+ arg_def->set_name(x.first);
+ arg_def->set_type(x.second);
+ }
+ for (const auto& x : outputs) {
+ auto* arg_def = signature->add_output_arg();
+ arg_def->set_name(x.first);
+ arg_def->set_type(x.second);
+ }
+ for (const auto& x : rets) {
+ (*func.mutable_ret())[x.first] = x.second;
+ }
+
+ return func;
+}
+
+TEST(FunctionDefInputDescTest, ConstructedCorrectly) {}
+
+// Before:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// | +-----------+ Arg0 +---+ Arg1 +----+ |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+//
+// After:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | | | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"ret0", "arg0"}, {"ret1", "arg1"}});
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"ret0", "ret1"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
+ EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+}
+
+// Before:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// | +-----------+ Arg0 +---+ Arg1 +----+ |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | +------+ | +---v--+ | |
+// | | |Const | | | Op0 | | |
+// | | +---v--+ | +---+--+ | |
+// | | | | | | |
+// | | | +---v--+ +---v--+ | |
+// | | +---| XOp1 | | XOp2 | | |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+// where XOp1 and XOp2 are not convertible.
+//
+// After:
+//
+// No change because the ops are not convertible.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+ NodeDef* x_op1 =
+ function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+ CHECK_NOTNULL(x_op1);
+
+ NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
+ CHECK_NOTNULL(x_op2);
+
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
+ {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x", "y"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDef outer_copy(outer);
+ FunctionDef inner_copy(inner);
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ // They should be unchanged
+ EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+ EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +---------------+ Arg0 +-------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +---------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +----------+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +-------------------+
+// | +---+--+ |
+// | | |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +----------+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
+ // Tests that behavior is correct when an output is used more than once.
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}, {"ret1", DT_INT64}},
+ {{"ret0", "Cast:y:0"}, {"ret1", "Cast:y:0"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}, {"mapdefun_0", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64, DT_INT64},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +------------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v---+ num=3 | |
+// | | |Unstack| axis=0 | |
+// | | ++--+--++ | |
+// | | | | | | |
+// | | +----+ | +-------+ | |
+// | | | | | | |
+// | | MapDefun +---v--+ +-v----+ +--v---+ | |
+// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ |
+// | +---+--+ +--+---+ +--+---+ |
+// | | | | |
+// | +---v--+ +--v---+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v---+ num=3 |
+// | |Unstack| axis=1 |
+// | ++--+--++ |
+// | | | | |
+// | +----+ | +-------+ |
+// | | | | |
+// | | | | |
+// | +---v--+ +-v----+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
+ FunctionDef inner = CreateFunction(
+ "inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+ {{"ret0", "MyUnstack:output:0"},
+ {"ret1", "MyUnstack:output:1"},
+ {"ret2", "MyUnstack:output:2"}});
+ NodeDef* unstack_op =
+ AddUnstackNode("MyUnstack", {"arg0"}, DT_INT32, 0, 3, &inner);
+ CHECK_NOTNULL(unstack_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT32},
+ {"mapdefun_0", DT_INT32},
+ {"mapdefun_1", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"},
+ {"mapdefun_0", "MapDefun:output:1"},
+ {"mapdefun_1", "MapDefun:output:2"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+ {{1}, {1}, {1}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& unpack_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ EXPECT_EQ(unpack_node.input(0), "x");
+ EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+ EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+ EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(unpack_node.name(), ":output:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(unpack_node.name(), ":output:1"));
+ EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ strings::StrCat(unpack_node.name(), ":output:2"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +------------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | +---+--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v---+ num=3 | |
+// | | |Unstack| axis=0 | |
+// | | ++--+--++ | |
+// | | | | | | |
+// | | +----+ | +-------+ | |
+// | | | | | | |
+// | | MapDefun +---v--+ +-v----+ +--v---+ | |
+// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ |
+// | +---+--+ +--+---+ +--+---+ |
+// | | | | |
+// | +---v--+ +--v---+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---+--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v---+ num=3 |
+// | |Unstack| axis=1 |
+// | ++--+--++ |
+// | | | | |
+// | +----+ | +-------+ |
+// | | | | |
+// | | | | |
+// | +---v--+ +-v----+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
+ // Cast converts int32 -> int64, so everything downstream of it (the Unstack
+ // node and all function outputs) is int64.
+ FunctionDef inner = CreateFunction(
+ "inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}, {"ret1", DT_INT64}, {"ret2", DT_INT64}},
+ {{"ret0", "MyUnstack:output:0"},
+ {"ret1", "MyUnstack:output:1"},
+ {"ret2", "MyUnstack:output:2"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+ NodeDef* unstack_op =
+ AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT64, 0, 3, &inner);
+ CHECK_NOTNULL(unstack_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64},
+ {"mapdefun_0", DT_INT64},
+ {"mapdefun_1", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"},
+ {"mapdefun_0", "MapDefun:output:1"},
+ {"mapdefun_1", "MapDefun:output:2"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x"}, {DT_INT32}, {DT_INT64, DT_INT64, DT_INT64},
+ {{1}, {1}, {1}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_FALSE(function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ // Abort the test instead of indexing node_def() out of range if either
+ // node is missing (assumes a negative index signals "not found").
+ const int cast_index = function_utils::FindFunctionNodeWithOp("Cast", outer);
+ ASSERT_GE(cast_index, 0);
+ const NodeDef& cast_node = outer.node_def(cast_index);
+ EXPECT_EQ(cast_node.input(0), "x");
+ const int unpack_index =
+ function_utils::FindFunctionNodeWithOp("Unpack", outer);
+ ASSERT_GE(unpack_index, 0);
+ const NodeDef& unpack_node = outer.node_def(unpack_index);
+ EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+ EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT64);
+ EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(unpack_node.name(), ":output:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(unpack_node.name(), ":output:1"));
+ EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ strings::StrCat(unpack_node.name(), ":output:2"));
+ EXPECT_EQ(outer.node_def_size(), 2);
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | +---------+ | |
+// | | +---v--+ | | |
+// | | |Print | | | |
+// | | +---+--+ | | |
+// | | : +---v--+ | |
+// | | ::::::> Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// No change because we don't deal with control inputs for now.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+ // The attrs of the Print node aren't relevant to this test; it only serves
+ // as the source of a control input ("^Print") to Cast.
+ NodeDef* print_op =
+ function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+ CHECK_NOTNULL(print_op);
+ NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
+ false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDef outer_copy(outer);
+ FunctionDef inner_copy(inner);
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ // Neither the outer nor the inner function should be modified.
+ EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+ EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+}
+
+// TODO(rachelim): More test cases when we get around to implementing them:
+// [] A badly defined converter, e.g. doesn't produce nodes that have the
+// same number of outputs/inputs as the nodes to be converted
+// [] Converter where the 'converted' form has multiple nodes.
+// [] Case with dependent nodes, e.g. ops with const inputs that are
+// broadcasted.
+// [] Python-side tests to actually run the functions to make sure
+// they work.
+
+} // namespace
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 1ed1b22..4b0cbfa 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -352,7 +352,7 @@
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
- LOG(INFO) << "Starting optimization for grappler item: " << item.id;
+ VLOG(1) << "Starting optimization for grappler item: " << item.id;
optimization_results_.clear();
// 1. Optimize main graph
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 03e36a7..008a289 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -218,7 +218,7 @@
void Remapper::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/,
const GraphDef& /*optimized_graph*/,
double /*result*/) {
- // Nothing to do for ArithmeticOptimizer.
+ // Nothing to do for Remapper: it ignores optimization feedback.
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index caa0b7b..4542d17 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -20,10 +20,9 @@
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
-
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index e540cc0..bdbb883 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -1,6 +1,10 @@
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_protos_grappler",
+)
cc_library(
name = "scc",
@@ -210,3 +214,28 @@
"//tensorflow/core:testlib",
],
)
+
+cc_library(
+ name = "symbolic_shapes",
+ srcs = ["symbolic_shapes.cc"],
+ hdrs = ["symbolic_shapes.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ] + tf_protos_grappler(),
+)
+
+tf_cc_test(
+ name = "symbolic_shapes_test",
+ srcs = ["symbolic_shapes_test.cc"],
+ deps = [
+ ":symbolic_shapes",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc b/tensorflow/core/grappler/utils/symbolic_shapes.cc
similarity index 98%
rename from tensorflow/core/grappler/optimizers/symbolic_shapes.cc
rename to tensorflow/core/grappler/utils/symbolic_shapes.cc
index 155843a..1666de4 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc
+++ b/tensorflow/core/grappler/utils/symbolic_shapes.cc
@@ -13,7 +13,7 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.h b/tensorflow/core/grappler/utils/symbolic_shapes.h
similarity index 94%
rename from tensorflow/core/grappler/optimizers/symbolic_shapes.h
rename to tensorflow/core/grappler/utils/symbolic_shapes.h
index ace7bd1..0a7d8ac 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.h
+++ b/tensorflow/core/grappler/utils/symbolic_shapes.h
@@ -13,8 +13,8 @@
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
+#define TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
@@ -74,4 +74,4 @@
} // namespace grappler
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc
similarity index 98%
rename from tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc
rename to tensorflow/core/grappler/utils/symbolic_shapes_test.cc
index 7ce995d..6ac644c 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc
+++ b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc
@@ -13,7 +13,7 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ef176a7..08245e6 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -51,6 +51,10 @@
"tf_kernel_tests_linkstatic",
)
load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
+load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
"if_mkl_ml",
@@ -637,6 +641,7 @@
":reshape_op",
":reverse_op",
":reverse_sequence_op",
+ ":searchsorted_op",
":shape_ops",
":slice_op",
":snapshot_op",
@@ -870,6 +875,12 @@
)
tf_kernel_library(
+ name = "searchsorted_op",
+ prefix = "searchsorted_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
name = "inplace_ops",
prefix = "inplace_ops",
deps = ARRAY_DEPS,
@@ -1106,7 +1117,7 @@
name = "depthwise_conv_ops_test",
size = "small",
srcs = ["depthwise_conv_ops_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":conv_ops",
":image",
@@ -2703,6 +2714,7 @@
)
LOGGING_DEPS = [
+ "@com_google_absl//absl/strings",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -2760,6 +2772,7 @@
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
],
)
@@ -4397,6 +4410,7 @@
":reduce_join_op",
":regex_full_match_op",
":regex_replace_op",
+ ":string_format_op",
":string_join_op",
":string_length_op",
":string_split_op",
@@ -4428,6 +4442,30 @@
)
tf_kernel_library(
+ name = "string_format_op",
+ prefix = "string_format_op",
+ deps = STRING_DEPS + ["@com_google_absl//absl/strings"],
+)
+
+tf_cc_test(
+ name = "string_format_op_test",
+ size = "small",
+ srcs = ["string_format_op_test.cc"],
+ deps = [
+ ":string_format_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "string_join_op",
prefix = "string_join_op",
deps = STRING_DEPS,
diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
index 6074b3e..7d09e9b 100644
--- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
@@ -17,7 +17,7 @@
#define EIGEN_USE_GPU
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
index b2efa06..4ae26fb 100644
--- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -334,30 +334,34 @@
// Proto to store debug outputs, per example.
boosted_trees::DebugOutput example_debug_info;
// Initial bias prediction. E.g., prediction based off training mean.
- example_debug_info.add_logits_path(resource->GetTreeWeight(0) *
- resource->node_value(0, 0));
+ float tree_logit =
+ resource->GetTreeWeight(0) * resource->node_value(0, 0);
+ example_debug_info.add_logits_path(tree_logit);
int32 node_id = 0;
int32 tree_id = 0;
int32 feature_id;
- float tree_logit;
float past_trees_logit = 0; // Sum of leaf logits from prior trees.
- // Populate proto.
+ // Go through each tree and populate proto.
while (tree_id <= last_tree) {
- // Feature id used to split.
- feature_id = resource->feature_id(tree_id, node_id);
- example_debug_info.add_feature_ids(feature_id);
- // Get logit after split.
- node_id = resource->next_node(tree_id, node_id, i,
- batch_bucketized_features);
- tree_logit = resource->GetTreeWeight(tree_id) *
- resource->node_value(tree_id, node_id);
- // Output logit incorporates sum of leaf logits from prior trees.
- example_debug_info.add_logits_path(tree_logit + past_trees_logit);
- if (resource->is_leaf(tree_id, node_id)) {
- // Move onto other trees.
- past_trees_logit += tree_logit;
+ if (resource->is_leaf(tree_id, node_id)) { // Move onto other trees.
+ // Accumulate tree_logits only if the leaf is non-root, but do so
+ // for bias tree.
+ if (tree_id == 0 || node_id > 0) {
+ past_trees_logit += tree_logit;
+ }
++tree_id;
node_id = 0;
+ } else { // Add to proto.
+ // Feature id used to split.
+ feature_id = resource->feature_id(tree_id, node_id);
+ example_debug_info.add_feature_ids(feature_id);
+ // Get logit after split.
+ node_id = resource->next_node(tree_id, node_id, i,
+ batch_bucketized_features);
+ tree_logit = resource->GetTreeWeight(tree_id) *
+ resource->node_value(tree_id, node_id);
+ // Output logit incorporates sum of leaf logits from prior trees.
+ example_debug_info.add_logits_path(tree_logit + past_trees_logit);
}
}
// Set output as serialized proto containing debug info.
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index b3ab7e2..0bb929b 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -27,6 +27,74 @@
namespace tensorflow {
namespace data {
+namespace {
+
+// Minimal `StepStatsCollectorInterface` that keeps only one number: the total
+// time (in nanoseconds) between executor start and executor end, summed over
+// every node it is asked to track. All other statistics are dropped.
+class SimpleStepStatsCollector : public StepStatsCollectorInterface {
+ public:
+ // Adds `delta` to the accumulated processing time under `mu_`.
+ void IncrementProcessingTime(int64 delta) {
+ mutex_lock lock(mu_);
+ processing_time_ += delta;
+ }
+
+ // Hands out a new per-node stats object bound to this collector. The object
+ // deletes itself in `Done()`.
+ NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override {
+ return new SimpleNodeExecStats(this);
+ }
+
+ // No allocation information is tracked, so there is nothing to report.
+ string ReportAllocsOnResourceExhausted(const string& err) override {
+ return "";
+ }
+
+ // Returns the accumulated processing time, read under a shared lock.
+ int64 processing_time() {
+ tf_shared_lock lock(mu_);
+ return processing_time_;
+ }
+
+ private:
+ // Per-node stats that remember only the executor start/end timestamps and
+ // forward their difference to the owning collector when done.
+ class SimpleNodeExecStats : public NodeExecStatsInterface {
+ public:
+ explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
+ : step_stats_collector_(step_stats_collector) {}
+
+ void Done(const string& device) override {
+ const int64 elapsed_ns = end_time_ns_ - start_time_ns_;
+ step_stats_collector_->IncrementProcessingTime(elapsed_ns);
+ delete this;
+ }
+
+ void RecordExecutorStarted() override {
+ start_time_ns_ = Env::Default()->NowNanos();
+ }
+
+ void RecordExecutorEnded() override {
+ end_time_ns_ = Env::Default()->NowNanos();
+ }
+
+ // The remaining hooks carry information this collector does not keep.
+ void RecordComputeStarted() override {}
+ void RecordComputeEnded() override {}
+ void SetMemory(OpKernelContext* ctx) override {}
+ void SetOutput(int slot, const Tensor* tensor) override {}
+ void SetReferencedTensors(const TensorReferenceVector& tensors) override {}
+ void SetScheduled(int64 nanos) override {}
+
+ private:
+ SimpleStepStatsCollector* step_stats_collector_; // Not owned.
+ int64 start_time_ns_ = 0;
+ int64 end_time_ns_ = 0;
+ };
+
+ mutex mu_;
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+};
+
+} // namespace
+
/* static */
Status CapturedFunction::Create(
const NameAttrList& func, OpKernelContext* ctx, const string& argument,
@@ -359,13 +427,13 @@
done(s);
return;
}
- auto frame =
+ OwnedArgsCallFrame* frame =
new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_);
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager();
- auto step_container = new ScopedStepContainer(
+ ScopedStepContainer* step_container = new ScopedStepContainer(
f_opts.step_id, [resource_mgr](const string& name) {
resource_mgr->Cleanup(name).IgnoreError();
});
@@ -380,24 +448,19 @@
// (such as queue kernels) that depend on the non-nullness of
// `OpKernelContext::cancellation_manager()`, but additional effort
// will be required to plumb it through the `IteratorContext`.
- auto c_mgr = new CancellationManager;
+ CancellationManager* c_mgr = new CancellationManager;
f_opts.cancellation_manager = c_mgr;
- StepStats* stats = nullptr;
- StepStatsCollector* stats_collector = nullptr;
- std::shared_ptr<model::Node> node;
+ std::shared_ptr<SimpleStepStatsCollector> stats_collector;
if (ctx->model()) {
- node = ctx->model()->LookupNode(prefix);
- if (node) {
- // TODO(b/114104975): Use something light-weight here.
- stats = new StepStats();
- stats_collector = new StepStatsCollector(stats);
- }
+ stats_collector = MakeUnique<SimpleStepStatsCollector>();
}
- f_opts.stats_collector = stats_collector;
+ f_opts.stats_collector = stats_collector.get();
auto callback = std::bind(
- [rets, step_container, c_mgr, frame, stats, stats_collector, node](
- FunctionLibraryRuntime::DoneCallback done,
+ [rets, step_container, c_mgr, frame](
+ const FunctionLibraryRuntime::DoneCallback& done,
+ const std::shared_ptr<model::Model>& model, const string& prefix,
+ const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
// Begin unbound arguments.
Status s) {
delete step_container;
@@ -406,25 +469,17 @@
s = frame->ConsumeRetvals(rets);
}
delete frame;
- if (node) {
- int64 delta = 0;
- stats_collector->Finalize();
- for (auto dev_stats : stats->dev_stats()) {
- for (auto node_stats : dev_stats.node_stats()) {
- delta += node_stats.all_end_rel_nanos();
- }
- }
- delete stats_collector;
- delete stats;
- node->add_processing_time(delta);
- node->start_work();
+ if (model) {
+ model->AddProcessingTime(prefix, stats_collector->processing_time());
+ model->RecordStart(prefix, false /* stop_output */);
}
done(s);
- if (node) {
- node->stop_work();
+ if (model) {
+ model->RecordStop(prefix, false /* start_output */);
}
},
- std::move(done), std::placeholders::_1);
+ std::move(done), ctx->model(), prefix, std::move(stats_collector),
+ std::placeholders::_1);
ctx->lib()->Run(f_opts, handle, frame, std::move(callback));
}
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index 19c35f9..0088431 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -14,11 +14,13 @@
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace data {
@@ -139,7 +141,13 @@
class Iterator : public DatasetIterator<FilterDatasetBase> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<FilterDatasetBase>(params) {}
+ : DatasetIterator<FilterDatasetBase>(params),
+ filtered_elements_(0),
+ dropped_elements_(0) {
+ // Stats are tagged with the last "::"-separated component of the prefix.
+ std::vector<string> components =
+ str_util::Split(params.prefix, "::", str_util::SkipEmpty());
+ // `components` is empty when the prefix is empty or made only of "::"
+ // separators; calling back() on an empty vector is undefined behavior.
+ prefix_end_ = components.empty() ? params.prefix : components.back();
+ }
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
@@ -154,6 +162,7 @@
// `input_impl_` and `f` are thread-safe. However, if multiple
// threads enter this method, outputs may be observed in a
// non-deterministic order.
+ auto stats_aggregator = ctx->stats_aggregator();
bool matched;
do {
{
@@ -176,8 +185,34 @@
if (!matched) {
// Clear the output tensor list since it didn't match.
out_tensors->clear();
+ if (stats_aggregator) {
+ mutex_lock l(mu_);
+ dropped_elements_++;
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::dropped_elements"),
+ static_cast<float>((dropped_elements_)));
+ // TODO(shivaniagrawal): multiple pipelines would collect
+ // aggregated number of dropped elements for all the pipelines,
+ // exploit tagged_context here.
+ stats_aggregator->IncrementCounter(
+ prefix_end_, "dropped_elements", static_cast<float>(1));
+ }
}
} while (!matched);
+ // TODO(shivaniagrawal): add ratio of dropped_elements and
+ // filtered_elements as a histogram.
+ if (stats_aggregator) {
+ mutex_lock l(mu_);
+ filtered_elements_++;
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::filtered_elements"),
+ static_cast<float>((filtered_elements_)));
+ // TODO(shivaniagrawal): multiple pipelines would collect aggregated
+ // number of filtered elements for all the pipelines, exploit
+ // tagged_context here.
+ stats_aggregator->IncrementCounter(prefix_end_, "filtered_elements",
+ static_cast<float>(1));
+ }
*end_of_sequence = false;
return Status::OK();
}
@@ -190,6 +225,10 @@
else
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impls_empty"), ""));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("filtered_elements"),
+ filtered_elements_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("dropped_elements"),
+ dropped_elements_));
return Status::OK();
}
@@ -200,12 +239,19 @@
input_impl_.reset();
else
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("filtered_elements"),
+ &filtered_elements_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("dropped_elements"),
+ &dropped_elements_));
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ int64 filtered_elements_ GUARDED_BY(mu_);
+ int64 dropped_elements_ GUARDED_BY(mu_);
+ string prefix_end_;
};
const DatasetBase* const input_;
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 8389621..2bbf4af 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#define EIGEN_USE_THREADS
+#include <atomic>
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
@@ -202,17 +203,9 @@
AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
if (num_parallel_calls_ == kAutoTune) {
num_parallel_calls_ = 1;
- std::function<void(int64)> set_fn = [this](int64 value) {
- {
- mutex_lock l(mu_);
- num_parallel_calls_ = value;
- }
- VLOG(2) << "setting parallelism knob to " << value;
- cond_var_.notify_all();
- };
- AddTunableParameter(
- ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */,
- port::NumSchedulableCPUs() /* max */, std::move(set_fn));
+ AddTunableParameter(ctx, "parallelism",
+ &num_parallel_calls_ /* value */, 1 /* min */,
+ port::NumSchedulableCPUs() /* max */, &cond_var_);
} else {
AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
}
@@ -230,14 +223,14 @@
EnsureRunnerThreadStarted(ctx);
while (batch_results_.empty() ||
batch_results_.front()->num_calls > 0) {
- StopWork(ctx);
+ RecordStop(ctx);
cond_var_.wait(l);
- StartWork(ctx);
+ RecordStart(ctx);
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
@@ -340,11 +333,9 @@
void CallCompleted(const std::shared_ptr<BatchResult>& result)
LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- num_calls_--;
- result->num_calls--;
- }
+ // Records that one call contributing to `result` has finished: decrements
+ // both the iterator-wide and the per-batch outstanding-call counters, then
+ // wakes every thread waiting on `cond_var_`.
+ mutex_lock l(mu_);
+ num_calls_--;
+ result->num_calls--;
+ // NOTE(review): the notification now happens while `l` still holds `mu_`;
+ // presumably so `cond_var_` is not touched after the lock is released —
+ // confirm the intent.
cond_var_.notify_all();
}
@@ -437,11 +428,6 @@
result->output_allocated = true;
}
- int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return (num_parallel_calls_ + dataset()->batch_size_ - 1) /
- dataset()->batch_size_;
- }
-
Status ProcessResult(IteratorContext* ctx,
const std::shared_ptr<BatchResult>& result,
std::vector<Tensor>* out_tensors,
@@ -490,34 +476,34 @@
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
LOCKS_EXCLUDED(mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
- StartWork(ctx.get());
+ RecordStart(ctx.get());
auto stop_cleanup =
- gtl::MakeCleanup([this, &ctx]() { StopWork(ctx.get()); });
- {
- tf_shared_lock l(mu_);
- new_calls.reserve(num_parallel_calls_);
- }
+ gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
+ new_calls.reserve(num_parallel_calls_);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_;
+ int64 max_batch_results =
+ (num_parallel_calls + dataset()->batch_size_ - 1) /
+ dataset()->batch_size_;
+ return num_calls_ >= num_parallel_calls ||
+ (batch_results_.size() > max_batch_results ||
+ (batch_results_.size() == max_batch_results &&
+ call_counter_ % dataset()->batch_size_ == 0));
+ };
while (true) {
{
mutex_lock l(mu_);
- while (!cancelled_ &&
- (num_calls_ >= num_parallel_calls_ ||
- batch_results_.size() > MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ == 0))) {
- StopWork(ctx.get());
+ while (!cancelled_ && busy()) {
+ RecordStop(ctx.get());
cond_var_.wait(l);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
- while (num_calls_ < num_parallel_calls_ &&
- (batch_results_.size() < MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ != 0))) {
+ while (!busy()) {
if (call_counter_ % dataset()->batch_size_ == 0) {
batch_results_.emplace_back(
new BatchResult(dataset()->batch_size_));
@@ -662,7 +648,7 @@
// the `batch_results_` buffer.
condition_variable cond_var_;
// Identifies the maximum number of parallel calls.
- int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
+ std::atomic<int64> num_parallel_calls_;
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index 63025d3..9aa505f 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -77,6 +77,14 @@
: DatasetIterator<Dataset>(params),
model_(std::make_shared<model::Model>()) {}
+ ~Iterator() override {
+ // Signal the optimize thread to terminate. Also take ownership of it under
+ // the lock, so that it is joined (by `Thread`'s destructor) when the local
+ // goes out of scope at the end of this destructor body. That happens before
+ // any member of `this` is destroyed. Letting `optimize_thread_` be destroyed
+ // implicitly would instead make shutdown depend on member declaration order
+ // (e.g. `input_impl_` and `cancelled_` are destroyed before it).
+ std::unique_ptr<Thread> optimize_thread;
+ {
+ mutex_lock l(mu_);
+ cancelled_ = true;
+ cond_var_.notify_all();
+ optimize_thread = std::move(optimize_thread_);
+ }
+ // `optimize_thread` is destroyed (and joined) here, with `mu_` released so
+ // the thread can observe `cancelled_` and exit.
+ }
+
Status Initialize(IteratorContext* ctx) override {
IteratorContext ctx_with_model(CreateParams(ctx));
return dataset()->input_->MakeIterator(&ctx_with_model, prefix(),
@@ -87,21 +95,7 @@
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
- int64 now = ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
- if (last_optimization_ms_ + optimization_period_ms_ < now) {
- model_->Optimize(port::NumSchedulableCPUs());
- // Exponentially increase the period of running the optimization until
- // a threshold is reached.
- if (optimization_period_ms_ < kOptimizationPeriodThresholdMs) {
- if (optimization_period_ms_ << 1 < kOptimizationPeriodThresholdMs) {
- optimization_period_ms_ <<= 1;
- } else {
- optimization_period_ms_ = kOptimizationPeriodThresholdMs;
- }
- }
- last_optimization_ms_ =
- ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
- }
+ TF_RETURN_IF_ERROR(EnsureOptimizeThreadStarted(ctx));
IteratorContext ctx_with_model(CreateParams(ctx));
return input_impl_->GetNext(&ctx_with_model, out_tensors,
end_of_sequence);
@@ -128,10 +122,53 @@
}
private:
+ Status EnsureOptimizeThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ // The optimize thread is started lazily, on first use, and only once.
+ if (optimize_thread_) {
+ return Status::OK();
+ }
+ // Hand the thread its own copy of the iterator context.
+ auto thread_ctx = std::make_shared<IteratorContext>(*ctx);
+ optimize_thread_.reset(ctx->env()->StartThread(
+ {}, "optimize_thread",
+ [this, thread_ctx]() { OptimizeThread(thread_ctx); }));
+ return Status::OK();
+ }
+
+ // Background loop that periodically runs `model_->Optimize()` until
+ // `cancelled_` is set. The period starts at 10ms and doubles after each run,
+ // capped at kOptimizationPeriodThresholdMs.
+ void OptimizeThread(const std::shared_ptr<IteratorContext>& ctx) {
+ int64 last_optimization_ms = 0;
+ int64 optimization_period_ms = 10;
+ while (true) {
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_) {
+ // Read the clock once per check so that the condition and the wait
+ // duration are derived from the same timestamp.
+ const int64 now_ms =
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
+ const int64 next_optimization_ms =
+ last_optimization_ms + optimization_period_ms;
+ if (now_ms > next_optimization_ms) break;
+ // Optimization runs once `now_ms` exceeds `next_optimization_ms`, so
+ // wait until `next_optimization_ms + 1`. This wait is always at least
+ // 1ms, which avoids busy-spinning on zero-length waits.
+ cond_var_.wait_for(
+ l, std::chrono::milliseconds(next_optimization_ms + 1 - now_ms));
+ }
+ if (cancelled_) return;
+ }
+ model_->Optimize(port::NumSchedulableCPUs());
+ // Exponentially increase the period of running the optimization
+ // until a threshold is reached.
+ if (optimization_period_ms < kOptimizationPeriodThresholdMs) {
+ if (optimization_period_ms << 1 < kOptimizationPeriodThresholdMs) {
+ optimization_period_ms <<= 1;
+ } else {
+ optimization_period_ms = kOptimizationPeriodThresholdMs;
+ }
+ }
+ last_optimization_ms =
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
+ }
+ }
+
mutex mu_;
+ condition_variable cond_var_;
std::shared_ptr<model::Model> model_;
- int64 last_optimization_ms_ GUARDED_BY(mu_) = 0;
- int64 optimization_period_ms_ GUARDED_BY(mu_) = 10;
+ std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_) = false;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 9cd46bf..2e6e046 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -12,6 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <atomic>
#include <deque>
#include <utility>
@@ -344,13 +345,13 @@
if (must_wait_for_input) {
// Wait for elements to become available.
- StopWork(ctx);
+ RecordStop(ctx);
if (dataset()->sloppy_) {
sloppy_cond_var_.wait(l);
} else {
workers_[interleave_indices_[next_index_]].cond_var.wait(l);
}
- StartWork(ctx);
+ RecordStart(ctx);
}
}
return errors::Cancelled(
@@ -618,11 +619,11 @@
// std::function arguments are copy-constructable, so we pass raw
// pointers, and then immediately wrap them to ensure correct ownership.
- StartWork(ctx.get());
+ RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
mutex_lock l(mu_);
workers_[thread_index].cond_var.notify_all();
- StopWork(ctx.get());
+ RecordStop(ctx.get());
});
bool make_new_iterator;
{
@@ -660,9 +661,9 @@
if (read_new_input) {
mutex_lock l(mu_);
while (!cancelled_ && !workers_[thread_index].is_producing) {
- StopWork(ctx.get());
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
}
if (cancelled_) return;
// Copy the input tensors so that we do not need to block on `mu_`
@@ -712,9 +713,9 @@
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
- StopWork(ctx.get());
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
}
if (cancelled_) return;
tf_shared_lock ckpt_l(ckpt_mu_);
@@ -763,9 +764,9 @@
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
- StopWork(ctx.get());
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
}
if (cancelled_) return;
@@ -1213,10 +1214,10 @@
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
+ num_parallel_calls_(params.dataset->num_parallel_calls_),
args_list_(params.dataset->cycle_length_),
current_elements_(params.dataset->cycle_length_),
element_in_use_(params.dataset->cycle_length_, false),
- num_parallel_calls_(params.dataset->num_parallel_calls_),
thread_pool_(new thread::ThreadPool(
Env::Default(), ThreadOptions(), "parallel_interleave",
dataset()->cycle_length_ /* num_threads */,
@@ -1237,17 +1238,9 @@
mutex_lock l(mu_);
if (num_parallel_calls_ == kAutoTune) {
num_parallel_calls_ = 1;
- auto set_fn = [this](int64 value) {
- {
- mutex_lock l(mu_);
- num_parallel_calls_ = value;
- }
- VLOG(2) << "setting parallelism knob to " << value;
- cond_var_.notify_all();
- };
- AddTunableParameter(
- ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */,
- dataset()->cycle_length_ /* max */, std::move(set_fn));
+ AddTunableParameter(ctx, "parallelism",
+ &num_parallel_calls_ /* value */, 1 /* min */,
+ dataset()->cycle_length_ /* max */, &cond_var_);
} else {
AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
}
@@ -1267,9 +1260,9 @@
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty() &&
(!end_of_input_ || num_open_ > 0)) {
- StopWork(ctx);
+ RecordStop(ctx);
cond_var_.wait(l);
- StartWork(ctx);
+ RecordStart(ctx);
}
if (!invocation_results_.empty()) {
std::swap(result, invocation_results_.front());
@@ -1278,11 +1271,11 @@
*end_of_sequence = true;
return Status::OK();
}
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
- StopWork(ctx);
+ RecordStop(ctx);
result->notification.WaitForNotification();
- StartWork(ctx);
+ RecordStart(ctx);
} while (result->skip);
if (result->status.ok()) {
@@ -1406,8 +1399,8 @@
const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
const std::vector<std::shared_ptr<InvocationResult>>& results)
LOCKS_EXCLUDED(mu_) {
- StartWork(ctx.get());
- auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
bool end_of_input = false;
for (auto& result : results) {
if (!end_of_input) {
@@ -1425,60 +1418,66 @@
// Release the ownership of the cycle element iterator, closing the
// iterator if end of input was encountered.
- {
- if (end_of_input) {
- current_elements_[cycle_index].reset();
- }
- mutex_lock l(mu_);
- element_in_use_[cycle_index] = false;
- num_calls_--;
- if (end_of_input) {
- args_list_[cycle_index].clear();
- num_open_--;
- }
+ if (end_of_input) {
+ current_elements_[cycle_index].reset();
+ }
+ mutex_lock l(mu_);
+ element_in_use_[cycle_index] = false;
+ num_calls_--;
+ if (end_of_input) {
+ args_list_[cycle_index].clear();
+ num_open_--;
}
cond_var_.notify_all();
}
- int64 MaxInvocationResults() {
- return dataset()->cycle_length_ * dataset()->block_length_;
- }
-
// Method responsible for 1) creating iterators out of input elements, 2)
// determining the order in which elements are fetched from the iterators,
// and 3) scheduling the fetching of the elements to a threadpool.
//
// This method runs in the `runner_thread` background thread.
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
- StartWork(ctx.get());
- auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ return element_in_use_[cycle_index_] ||
+ num_calls_ >= num_parallel_calls_ ||
+ invocation_results_.size() >=
+ dataset()->cycle_length_ * dataset()->block_length_;
+ };
while (true) {
- {
- mutex_lock l(mu_);
- // Wait until this thread is cancelled, the end of input has been
- // reached, or the cycle element at the `cycle_index_` position is
- // not in use and there is space in the `invocation_results_` queue.
- while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
- (element_in_use_[cycle_index_] ||
- num_calls_ >= num_parallel_calls_ ||
- invocation_results_.size() >= MaxInvocationResults())) {
- StopWork(ctx.get());
- cond_var_.wait(l);
- StartWork(ctx.get());
- }
+ mutex_lock l(mu_);
+ // Wait until this thread is cancelled, the end of input has been
+ // reached, or the cycle element at the `cycle_index_` position is
+ // not in use and there is space in the `invocation_results_` queue.
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
+ RecordStop(ctx.get());
+ cond_var_.wait(l);
+ RecordStart(ctx.get());
+ }
- if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
- return;
- }
+ if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+ return;
+ }
- while (!element_in_use_[cycle_index_] &&
- (!end_of_input_ || num_open_ > 0) &&
- num_calls_ < num_parallel_calls_ &&
- invocation_results_.size() < MaxInvocationResults()) {
- if (!current_elements_[cycle_index_]) {
- // Try to create a new iterator from the next input element.
- Status status = input_impl_->GetNext(
- ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ while ((!end_of_input_ || num_open_ > 0) && !busy()) {
+ if (!current_elements_[cycle_index_]) {
+ // Try to create a new iterator from the next input element.
+ Status status = input_impl_->GetNext(
+ ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ if (!end_of_input_) {
+ Status status = MakeIteratorFromInputElement(
+ ctx.get(), args_list_[cycle_index_], cycle_index_,
+ dataset()->captured_func_.get(), prefix(),
+ ¤t_elements_[cycle_index_]);
if (!status.ok()) {
invocation_results_.emplace_back(new InvocationResult());
std::shared_ptr<InvocationResult>& result =
@@ -1487,39 +1486,25 @@
result->notification.Notify();
break;
}
- if (!end_of_input_) {
- Status status = MakeIteratorFromInputElement(
- ctx.get(), args_list_[cycle_index_], cycle_index_,
- dataset()->captured_func_.get(), prefix(),
- ¤t_elements_[cycle_index_]);
- if (!status.ok()) {
- invocation_results_.emplace_back(new InvocationResult());
- std::shared_ptr<InvocationResult>& result =
- invocation_results_.back();
- result->status.Update(status);
- result->notification.Notify();
- break;
- }
- ++num_open_;
- }
+ ++num_open_;
}
- if (current_elements_[cycle_index_]) {
- // Pre-allocate invocation results for outputs to be fetched
- // and then fetch the outputs asynchronously.
- std::vector<std::shared_ptr<InvocationResult>> results;
- results.reserve(dataset()->block_length_);
- for (int i = 0; i < dataset()->block_length_; ++i) {
- invocation_results_.emplace_back(new InvocationResult());
- results.push_back(invocation_results_.back());
- }
- num_calls_++;
- element_in_use_[cycle_index_] = true;
- thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
- ctx, cycle_index_,
- std::move(results)));
- }
- cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
+ if (current_elements_[cycle_index_]) {
+ // Pre-allocate invocation results for outputs to be fetched
+ // and then fetch the outputs asynchronously.
+ std::vector<std::shared_ptr<InvocationResult>> results;
+ results.reserve(dataset()->block_length_);
+ for (int i = 0; i < dataset()->block_length_; ++i) {
+ invocation_results_.emplace_back(new InvocationResult());
+ results.push_back(invocation_results_.back());
+ }
+ num_calls_++;
+ element_in_use_[cycle_index_] = true;
+ thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+ ctx, cycle_index_,
+ std::move(results)));
+ }
+ cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
cond_var_.notify_all();
}
@@ -1622,6 +1607,9 @@
// and there are elements left to be fetched.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ std::atomic<int64> num_parallel_calls_;
+
// Iterator for input elements.
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
@@ -1648,9 +1636,6 @@
// Identifies the number of open iterators.
int64 num_open_ GUARDED_BY(mu_) = 0;
- // Identifies the maximum number of parallel calls.
- int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
-
// Identifies the number of outstanding calls.
int64 num_calls_ GUARDED_BY(mu_) = 0;
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 5f6052c..ee20249b 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
+#include <atomic>
#include <deque>
#include <functional>
#include <utility>
@@ -40,11 +41,6 @@
num_parallel_calls_(num_parallel_calls) {}
~ParallelMapIterator() override {
- // TODO(mrry): Replace this cancellation logic with a
- // CancellationManager. The syntax would be more heavyweight,
- // but it would be possible to thread a cancellation manager
- // through the IteratorContext to upstream,
- // potentially-blocking iterators, when we add these.
mutex_lock l(mu_);
// Cancel the runner thread.
cancelled_ = true;
@@ -59,19 +55,11 @@
mutex_lock l(mu_);
if (num_parallel_calls_ == kAutoTune) {
num_parallel_calls_ = 1;
- auto set_fn = [this](int64 value) {
- {
- mutex_lock l(mu_);
- num_parallel_calls_ = value;
- }
- VLOG(2) << "setting parallelism knob to " << value;
- cond_var_.notify_all();
- };
// TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
// use it here for the maximum.
- AddTunableParameter(ctx, "parallelism", num_parallel_calls_ /* value */,
+ AddTunableParameter(ctx, "parallelism", &num_parallel_calls_ /* value */,
1 /* min */, port::NumSchedulableCPUs() /* max */,
- std::move(set_fn));
+ &cond_var_);
} else {
AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
}
@@ -90,17 +78,17 @@
mutex_lock l(mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty()) {
- StopWork(ctx);
+ RecordStop(ctx);
cond_var_.wait(l);
- StartWork(ctx);
+ RecordStart(ctx);
}
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
- StopWork(ctx);
+ RecordStop(ctx);
result->notification.WaitForNotification();
- StartWork(ctx);
+ RecordStart(ctx);
return ProcessResult(result, out_tensors, end_of_sequence);
}
@@ -201,9 +189,9 @@
{
mutex_lock l(mu_);
num_calls_--;
+ cond_var_.notify_all();
}
result->notification.Notify();
- cond_var_.notify_all();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
@@ -218,9 +206,8 @@
return;
}
- // Call `func_(input_element)`, store the result in
- // `result->return_values`, and notify `result->notification` to unblock
- // a consumer.
+ // Call `func_(input_element)`, store the result in `result->return_values`,
+ // and notify `result->notification` to unblock a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
@@ -249,34 +236,33 @@
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
- StartWork(ctx.get());
- auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
- {
- tf_shared_lock l(mu_);
- new_calls.reserve(num_parallel_calls_);
- }
+ new_calls.reserve(num_parallel_calls_);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_;
+ return num_calls_ >= num_parallel_calls ||
+ invocation_results_.size() >= num_parallel_calls;
+ };
while (true) {
{
mutex_lock l(mu_);
- while (!cancelled_ &&
- (num_calls_ >= num_parallel_calls_ ||
- invocation_results_.size() >= num_parallel_calls_)) {
- StopWork(ctx.get());
+ while (!cancelled_ && busy()) {
+ RecordStop(ctx.get());
cond_var_.wait(l);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
- while (num_calls_ < num_parallel_calls_ &&
- invocation_results_.size() < num_parallel_calls_) {
+ while (!busy()) {
invocation_results_.emplace_back(new InvocationResult());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
for (const auto& call : new_calls) {
CallFunction(ctx, call);
}
@@ -334,7 +320,7 @@
// buffer.
condition_variable cond_var_;
// Identifies the maximum number of parallel calls.
- int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
+ std::atomic<int64> num_parallel_calls_;
// Counts the number of outstanding calls.
int64 num_calls_ GUARDED_BY(mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 52c421c..754ed77 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -103,18 +103,18 @@
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
+ auto stats_aggregator = ctx->stats_aggregator();
{
mutex_lock l(mu_);
- auto stats_aggregator = ctx->stats_aggregator();
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
auto_tuner_.buffer_limit() != 0) {
auto_tuner_.RecordEmpty();
- StopWork(ctx);
+ RecordStop(ctx);
cond_var_.wait(l);
- StartWork(ctx);
+ RecordStart(ctx);
}
if (cancelled_) {
@@ -136,6 +136,14 @@
mutex_lock parent_l(parent_mu_);
mutex_lock l(mu_);
+ if (stats_aggregator) {
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_size"),
+ static_cast<float>(buffer_.size()));
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_capacity"),
+ static_cast<float>(auto_tuner_.buffer_limit()));
+ }
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
@@ -219,6 +227,12 @@
strings::StrCat(prefix_end_, "::buffer_utilization"),
{static_cast<float>(buffer_.size()) /
static_cast<float>(auto_tuner_.buffer_limit())});
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_size"),
+ static_cast<float>(buffer_.size()));
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_capacity"),
+ static_cast<float>(auto_tuner_.buffer_limit()));
}
// A new element is available. Forward the status from computing it, and
// (if we successfully got an element) the output values.
@@ -255,8 +269,8 @@
//
// It owns the iterator context passed to it.
void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
- StartWork(ctx.get());
- auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
while (true) {
std::vector<Tensor> value;
@@ -264,9 +278,9 @@
{
mutex_lock l(mu_);
while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) {
- StopWork(ctx.get());
+ RecordStop(ctx.get());
cond_var_.wait(l);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
}
if (cancelled_) {
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 2a25459..76afd6f 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -17,7 +17,7 @@
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/util_ptx.cuh"
+#include "third_party/cub/util_ptx.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
index 862a977..e7882ac 100644
--- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
@@ -35,10 +35,10 @@
#define EIGEN_USE_GPU
-#include "external/cub_archive/cub/device/device_radix_sort.cuh"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/iterator/constant_input_iterator.cuh"
-#include "external/cub_archive/cub/thread/thread_operators.cuh"
+#include "third_party/cub/device/device_radix_sort.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/iterator/constant_input_iterator.cuh"
+#include "third_party/cub/thread/thread_operators.cuh"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
index a88e9b0..374a058 100644
--- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
@@ -18,7 +18,7 @@
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index 6b6a14e..1ded012 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#include <iostream>
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -90,6 +91,59 @@
REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp);
+// Kernel for the `PrintV2` op: writes the single string held by input `input`
+// to the destination named by the `output_stream` attr (stdout, stderr, or
+// the INFO/WARNING/ERROR log).
+class PrintV2Op : public OpKernel {
+ public:
+  explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_));
+
+    // Reject unknown streams when the kernel is built, so Compute() only
+    // ever sees one of `valid_output_streams_`.
+    auto output_stream_index =
+        std::find(std::begin(valid_output_streams_),
+                  std::end(valid_output_streams_), output_stream_);
+    OP_REQUIRES(ctx, output_stream_index != std::end(valid_output_streams_),
+                UnknownOutputStreamError());
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* input;
+    OP_REQUIRES_OK(ctx, ctx->input("input", &input));
+    // `scalar<string>()` CHECK-fails (aborting the process) unless the tensor
+    // holds exactly one element; report a recoverable error instead.
+    OP_REQUIRES(
+        ctx, input->NumElements() == 1,
+        errors::InvalidArgument("Input is expected to be a single string, "
+                                "but got shape ",
+                                input->shape().DebugString()));
+    const string& msg = input->scalar<string>()();
+
+    if (output_stream_ == "stdout") {
+      std::cout << msg << std::endl;
+    } else if (output_stream_ == "stderr") {
+      std::cerr << msg << std::endl;
+    } else if (output_stream_ == "log(info)") {
+      // LOG() terminates each message with its own newline; appending
+      // std::endl here would emit an extra blank line.
+      LOG(INFO) << msg;
+    } else if (output_stream_ == "log(warning)") {
+      LOG(WARNING) << msg;
+    } else if (output_stream_ == "log(error)") {
+      LOG(ERROR) << msg;
+    } else {
+      // Unreachable after the constructor check; kept as a safeguard.
+      OP_REQUIRES(ctx, false, UnknownOutputStreamError());
+    }
+  }
+
+  // Every value accepted by the `output_stream` attr. The array bound must
+  // equal the number of initializers: any extra slot would be
+  // value-initialized to nullptr, and comparing a null `const char*` with a
+  // std::string (as std::find does in the constructor) is undefined behavior.
+  const char* valid_output_streams_[5] = {"stdout", "stderr", "log(info)",
+                                          "log(warning)", "log(error)"};
+
+ private:
+  // Builds the InvalidArgument status reported for an unrecognized
+  // `output_stream_`, listing every accepted value.
+  Status UnknownOutputStreamError() const {
+    string error_msg = strings::StrCat(
+        "Unknown output stream: ", output_stream_, ", Valid streams are:");
+    for (const char* valid_stream : valid_output_streams_) {
+      strings::StrAppend(&error_msg, " ", valid_stream);
+    }
+    return errors::InvalidArgument(error_msg);
+  }
+
+  string output_stream_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("PrintV2").Device(DEVICE_CPU), PrintV2Op);
+
class TimestampOp : public OpKernel {
public:
explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {}
diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc
index 5e6958f..a259d99 100644
--- a/tensorflow/core/kernels/logging_ops_test.cc
+++ b/tensorflow/core/kernels/logging_ops_test.cc
@@ -23,11 +23,33 @@
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
+class PrintingV2GraphTest : public OpsTestBase {
+ protected:
+  // Builds and initializes a PrintV2 kernel that writes to `output_stream`.
+  // Failures are returned (not CHECK-ed) so tests can assert on invalid
+  // configurations without crashing the test binary.
+  Status Init(const string& output_stream = "log(warning)") {
+    TF_RETURN_IF_ERROR(NodeDefBuilder("op", "PrintV2")
+                           .Input(FakeInput(DT_STRING))
+                           .Attr("output_stream", output_stream)
+                           .Finalize(node_def()));
+    return InitOp();
+  }
+
+  // Initializes the kernel for `output_stream` and prints a scalar string.
+  void ExpectPrintSucceeds(const string& output_stream) {
+    TF_ASSERT_OK(Init(output_stream));
+    AddInputFromArray<string>(TensorShape({}), {"bar"});
+    TF_ASSERT_OK(RunOpKernel());
+  }
+};
+
+TEST_F(PrintingV2GraphTest, StringSuccess) {
+  TF_ASSERT_OK(Init());
+  AddInputFromArray<string>(TensorShape({}), {"bar"});
+  TF_ASSERT_OK(RunOpKernel());
+}
+
+TEST_F(PrintingV2GraphTest, StringSuccessStdout) {
+  ExpectPrintSucceeds("stdout");
+}
+
+TEST_F(PrintingV2GraphTest, StringSuccessStderr) {
+  ExpectPrintSucceeds("stderr");
+}
+
+TEST_F(PrintingV2GraphTest, StringSuccessLogInfo) {
+  ExpectPrintSucceeds("log(info)");
+}
+
+TEST_F(PrintingV2GraphTest, StringSuccessLogError) {
+  ExpectPrintSucceeds("log(error)");
+}
+
+TEST_F(PrintingV2GraphTest, InvalidOutputStream) {
+  EXPECT_FALSE(Init("invalid_output_stream").ok());
+}
+
class PrintingGraphTest : public OpsTestBase {
protected:
Status Init(DataType input_type1, DataType input_type2, string msg = "",
diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc
index 7a64788..82dfece 100644
--- a/tensorflow/core/kernels/multinomial_op.cc
+++ b/tensorflow/core/kernels/multinomial_op.cc
@@ -75,7 +75,7 @@
// lambda. Since we want to let each worker have its own copy, we pass
// "gen" by reference and explicitly do a copy assignment here.
random::PhiloxRandom gen_copy = gen;
- // Skip takes units of 128 bytes. +3 is so rounding doesn't lead to
+ // Skip takes units of 128 bits. +3 is so rounding doesn't lead to
// us using the same state in different batches.
gen_copy.Skip(start_row * (num_samples + 3) / 4);
random::SimplePhilox simple_philox(&gen_copy);
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 5fb1c92..272aa3b 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -19,6 +19,7 @@
#include <deque>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
@@ -82,6 +83,9 @@
// NOTE(mrry): This method is deprecated. Use
// `tensorflow::batch_util::CopySliceToElement()` defined in
// "./batch_util.h" instead.
+ ABSL_DEPRECATED(
+ "Use `tensorflow::batch_util::CopySliceToElement()` defined in "
+ "\"./batch_util.h\" instead.")
static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
int64 index);
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 88b3c2a..bb8254e 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -21,11 +21,11 @@
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_segmented_reduce.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
-#include "external/cub_archive/cub/warp/warp_reduce.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_segmented_reduce.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/warp/warp_reduce.cuh"
#include "cuda/include/cuComplex.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/lib/core/bits.h"
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index e4ca89e..5318d8c 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -76,15 +76,7 @@
.HostMemory("output")
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
-REGISTER_KERNEL_BUILDER(
- Name("Sum")
- .Device(DEVICE_GPU)
- .TypeConstraint<int64>("T")
- .TypeConstraint<int32>("Tidx")
- .HostMemory("input")
- .HostMemory("output")
- .HostMemory("reduction_indices"),
- ReductionOp<CPUDevice, int64, int32, Eigen::internal::SumReducer<int64>>);
+
#endif
#ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/searchsorted_op.cc b/tensorflow/core/kernels/searchsorted_op.cc
new file mode 100644
index 0000000..dc627ac
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.cc
@@ -0,0 +1,249 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<CPUDevice, T, OutType> {
+  // For each of the `batch_size` rows, writes into the matching output row
+  // the std::upper_bound index of every entry of that row of `values` within
+  // the same row of `sorted_inputs`. All three tensors are flat row-major
+  // buffers of shape [batch_size, num_inputs] / [batch_size, num_values].
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+                        const typename TTypes<T, 1>::ConstTensor& values,
+                        int batch_size, int num_inputs, int num_values,
+                        typename TTypes<OutType, 1>::Tensor* output) {
+    // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+    for (int b = 0; b < batch_size; ++b) {
+      // Row offsets are computed in int64: batch_size * num_inputs can exceed
+      // INT_MAX even though each factor fits in an int.
+      const int64 row = static_cast<int64>(b);
+      const T* sorted_inputs_ptr = sorted_inputs.data() + row * num_inputs;
+      const T* values_ptr = values.data() + row * num_values;
+      OutType* output_ptr = output->data() + row * num_values;
+      for (int i = 0; i < num_values; ++i) {
+        output_ptr[i] =
+            std::upper_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+                             values_ptr[i]) -
+            sorted_inputs_ptr;
+      }
+    }
+
+    return Status::OK();
+  }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<CPUDevice, T, OutType> {
+  // For each of the `batch_size` rows, writes into the matching output row
+  // the std::lower_bound index of every entry of that row of `values` within
+  // the same row of `sorted_inputs`. All three tensors are flat row-major
+  // buffers of shape [batch_size, num_inputs] / [batch_size, num_values].
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+                        const typename TTypes<T, 1>::ConstTensor& values,
+                        int batch_size, int num_inputs, int num_values,
+                        typename TTypes<OutType, 1>::Tensor* output) {
+    // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+    for (int b = 0; b < batch_size; ++b) {
+      // Row offsets are computed in int64: batch_size * num_inputs can exceed
+      // INT_MAX even though each factor fits in an int.
+      const int64 row = static_cast<int64>(b);
+      const T* sorted_inputs_ptr = sorted_inputs.data() + row * num_inputs;
+      const T* values_ptr = values.data() + row * num_values;
+      OutType* output_ptr = output->data() + row * num_values;
+      for (int i = 0; i < num_values; ++i) {
+        output_ptr[i] =
+            std::lower_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+                             values_ptr[i]) -
+            sorted_inputs_ptr;
+      }
+    }
+
+    return Status::OK();
+  }
+};
+} // namespace functor
+
+template <typename Device, typename T, typename OutType>
+class UpperBoundOp : public OpKernel {
+ public:
+  explicit UpperBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  // For every row b, writes the std::upper_bound insertion index of each
+  // entry of values[b, :] within sorted_inputs[b, :]; the output has the
+  // shape of `values`.
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& sorted_inputs_t = ctx->input(0);
+    const Tensor& values_t = ctx->input(1);
+
+    // Both inputs are treated as [batch, n] matrices and dim_size(0) /
+    // dim_size(1) are read below, which is only valid for rank-2 shapes.
+    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(sorted_inputs_t.shape()),
+                errors::InvalidArgument("sorted_inputs must be 2-D, got ",
+                                        sorted_inputs_t.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(values_t.shape()),
+                errors::InvalidArgument("values must be 2-D, got ",
+                                        values_t.shape().DebugString()));
+
+    // must have same batch dim_size for both
+    OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+                Status(error::INVALID_ARGUMENT,
+                       "Leading dim_size of both tensors must match."));
+
+    // this is required because we do indexing in int32 on the GPU
+    OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+                Status(error::INVALID_ARGUMENT,
+                       "values tensor size must less than INT_MAX"));
+
+    Tensor* output_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+    if (output_t->dtype() == DT_INT32) {
+      OP_REQUIRES(ctx,
+                  FastBoundsCheck(sorted_inputs_t.dim_size(1),
+                                  std::numeric_limits<int>::max()),
+                  errors::InvalidArgument("trailing dim_size must less than "
+                                          "INT_MAX for int32 output type, was ",
+                                          sorted_inputs_t.dim_size(1)));
+    }
+    // The functor receives the row length as an `int`, so it must fit even
+    // for int64 outputs; otherwise it would be silently truncated.
+    OP_REQUIRES(
+        ctx, sorted_inputs_t.dim_size(1) <= std::numeric_limits<int>::max(),
+        errors::InvalidArgument(
+            "trailing dim_size of sorted_inputs must fit in an int, was ",
+            sorted_inputs_t.dim_size(1)));
+
+    auto output = output_t->template flat<OutType>();
+    const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+    const auto values = values_t.template flat<T>();
+    OP_REQUIRES_OK(
+        ctx, functor::UpperBoundFunctor<Device, T, OutType>::Compute(
+                 ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+                 sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+  }
+};
+
+template <typename Device, typename T, typename OutType>
+class LowerBoundOp : public OpKernel {
+ public:
+  explicit LowerBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  // For every row b, writes the std::lower_bound insertion index of each
+  // entry of values[b, :] within sorted_inputs[b, :]; the output has the
+  // shape of `values`.
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& sorted_inputs_t = ctx->input(0);
+    const Tensor& values_t = ctx->input(1);
+
+    // Both inputs are treated as [batch, n] matrices and dim_size(0) /
+    // dim_size(1) are read below, which is only valid for rank-2 shapes.
+    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(sorted_inputs_t.shape()),
+                errors::InvalidArgument("sorted_inputs must be 2-D, got ",
+                                        sorted_inputs_t.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(values_t.shape()),
+                errors::InvalidArgument("values must be 2-D, got ",
+                                        values_t.shape().DebugString()));
+
+    // must have same batch dim_size for both
+    OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+                Status(error::INVALID_ARGUMENT,
+                       "Leading dim_size of both tensors must match."));
+
+    // this is required because we do indexing in int32 on the GPU
+    OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+                Status(error::INVALID_ARGUMENT,
+                       "values tensor size must less than INT_MAX"));
+
+    Tensor* output_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+    if (output_t->dtype() == DT_INT32) {
+      OP_REQUIRES(ctx,
+                  FastBoundsCheck(sorted_inputs_t.dim_size(1),
+                                  std::numeric_limits<int>::max()),
+                  errors::InvalidArgument("trailing dim_size must less than "
+                                          "INT_MAX for int32 output type, was ",
+                                          sorted_inputs_t.dim_size(1)));
+    }
+    // The functor receives the row length as an `int`, so it must fit even
+    // for int64 outputs; otherwise it would be silently truncated.
+    OP_REQUIRES(
+        ctx, sorted_inputs_t.dim_size(1) <= std::numeric_limits<int>::max(),
+        errors::InvalidArgument(
+            "trailing dim_size of sorted_inputs must fit in an int, was ",
+            sorted_inputs_t.dim_size(1)));
+
+    auto output = output_t->template flat<OutType>();
+    const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+    const auto values = values_t.template flat<T>();
+    OP_REQUIRES_OK(
+        ctx, functor::LowerBoundFunctor<Device, T, OutType>::Compute(
+                 ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+                 sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+  }
+};
+
+// Registers one search kernel `op_class` for the op named `op_name` on
+// `device` (Eigen device type `device_type`), with element type `type` and
+// index type `out_type`.
+#define REGISTER_SEARCH_KERNEL(op_name, op_class, device, device_type, type, \
+                               out_type)                                    \
+  REGISTER_KERNEL_BUILDER(Name(op_name)                                     \
+                              .Device(device)                               \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<out_type>("out_type"),        \
+                          op_class<device_type, type, out_type>);
+
+// Registers UpperBound and LowerBound, each with int32 and int64 outputs,
+// for element type `type` on one device.
+#define REGISTER_KERNELS(device, device_type, type)                       \
+  REGISTER_SEARCH_KERNEL("UpperBound", UpperBoundOp, device, device_type, \
+                         type, int32)                                     \
+  REGISTER_SEARCH_KERNEL("UpperBound", UpperBoundOp, device, device_type, \
+                         type, int64)                                     \
+  REGISTER_SEARCH_KERNEL("LowerBound", LowerBoundOp, device, device_type, \
+                         type, int32)                                     \
+  REGISTER_SEARCH_KERNEL("LowerBound", LowerBoundOp, device, device_type, \
+                         type, int64)
+
+#define REGISTER_CPU_KERNELS(type) REGISTER_KERNELS(DEVICE_CPU, CPUDevice, type)
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
+#undef REGISTER_CPU_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_GPU_KERNELS(type) REGISTER_KERNELS(DEVICE_GPU, GPUDevice, type)
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#undef REGISTER_GPU_KERNELS
+
+#endif  // GOOGLE_CUDA
+
+#undef REGISTER_KERNELS
+#undef REGISTER_SEARCH_KERNEL
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/searchsorted_op.h b/tensorflow/core/kernels/searchsorted_op.h
new file mode 100644
index 0000000..f075bf0
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.h
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T, typename OutType>
+struct UpperBoundFunctor {
+ // Searches for values in sorted_inputs and returns the greatest possible
+ // index where they maintain sorted order.
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output);
+};
+
+template <typename Device, typename T, typename OutType>
+struct LowerBoundFunctor {
+ // Searches for values in sorted_inputs and returns the lowest possible
+ // index where they maintain sorted order.
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output);
+};
+} // namespace functor
+
+} // end namespace tensorflow
+#endif // TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
diff --git a/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
new file mode 100644
index 0000000..263b5bf
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace {
+template <typename T, typename OutType>
+__global__ void UpperBoundKernel(const T* sorted_inputs, int batch_size,
+ int sorted_inputs_size, int values_size,
+ const T* values, OutType* outputs) {
+ CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
+ int bid = work_unit_id / values_size;
+ T value = values[work_unit_id];
+ outputs[work_unit_id] = cuda_helper::upper_bound<T, OutType>(
+ sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value);
+ }
+}
+
+template <typename T, typename OutType>
+__global__ void LowerBoundKernel(const T* sorted_inputs, int batch_size,
+ int sorted_inputs_size, int values_size,
+ const T* values, OutType* outputs) {
+ CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
+ int bid = work_unit_id / values_size;
+ T value = values[work_unit_id];
+ outputs[work_unit_id] = cuda_helper::lower_bound<T, OutType>(
+ sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value);
+ }
+}
+} // namespace
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<GPUDevice, T, OutType> {
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+                        const typename TTypes<T, 1>::ConstTensor& values,
+                        int batch_size, int num_inputs, int num_values,
+                        typename TTypes<OutType, 1>::Tensor* output) {
+    // Empty `values`: nothing to do, and CUDA rejects a zero-block grid.
+    if (values.size() == 0) return Status::OK();
+    const cudaStream_t& stream = GetCudaStream(context);
+    CudaLaunchConfig config =
+        GetCudaLaunchConfig(values.size(), context->eigen_gpu_device());
+    UpperBoundKernel<T>
+        <<<config.block_count, config.thread_per_block, 0, stream>>>(
+            sorted_inputs.data(), batch_size, num_inputs, num_values,
+            values.data(), output->data());
+    return Status::OK();
+  }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<GPUDevice, T, OutType> {
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+                        const typename TTypes<T, 1>::ConstTensor& values,
+                        int batch_size, int num_inputs, int num_values,
+                        typename TTypes<OutType, 1>::Tensor* output) {
+    // Empty `values`: nothing to do, and CUDA rejects a zero-block grid.
+    if (values.size() == 0) return Status::OK();
+    const cudaStream_t& stream = GetCudaStream(context);
+    CudaLaunchConfig config =
+        GetCudaLaunchConfig(values.size(), context->eigen_gpu_device());
+    LowerBoundKernel<T>
+        <<<config.block_count, config.thread_per_block, 0, stream>>>(
+            sorted_inputs.data(), batch_size, num_inputs, num_values,
+            values.data(), output->data());
+    return Status::OK();
+  }
+};
+}  // namespace functor
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::UpperBoundFunctor<GPUDevice, type, int32>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::UpperBoundFunctor<GPUDevice, type, int64>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::LowerBoundFunctor<GPUDevice, type, int32>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::LowerBoundFunctor<GPUDevice, type, int64>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/string_format_op.cc b/tensorflow/core/kernels/string_format_op.cc
new file mode 100644
index 0000000..e4a1887
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class StringFormatOp : public OpKernel {
+ public:
+  // Splits `template` on `placeholder`; requires one input per placeholder.
+  explicit StringFormatOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    string template_;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("template", &template_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("placeholder", &placeholder_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
+    split_template_ = absl::StrSplit(template_, placeholder_);
+    int64 num_placeholders = split_template_.size() - 1;
+    OP_REQUIRES(ctx, ctx->num_inputs() == num_placeholders,
+                errors::InvalidArgument(strings::StrCat(
+                    "num placeholders in template and num inputs must match: ",
+                    num_placeholders, " vs. ", ctx->num_inputs())));
+  }
+
+  // Emits a scalar string: template pieces interleaved with input summaries.
+  void Compute(OpKernelContext* ctx) override {
+    Tensor* formatted_string = nullptr;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(0, TensorShape({}), &formatted_string));
+    // Build directly into the output. Pieces are appended as whole strings
+    // (not via c_str()) so a template containing '\0' is not truncated.
+    string& msg = formatted_string->scalar<string>()();
+    strings::StrAppend(&msg, split_template_[0]);
+    for (int i = 0; i < ctx->num_inputs(); ++i) {
+      strings::StrAppend(&msg, ctx->input(i).SummarizeValue(summarize_, true));
+      strings::StrAppend(&msg, split_template_[i + 1]);
+    }
+  }
+
+ private:
+  int32 summarize_ = 0;
+  string placeholder_;
+  std::vector<std::string> split_template_;  // size == num_inputs() + 1
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringFormat").Device(DEVICE_CPU),
+ StringFormatOp);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/string_format_op_test.cc b/tensorflow/core/kernels/string_format_op_test.cc
new file mode 100644
index 0000000..13130a5
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+class StringFormatGraphTest : public OpsTestBase {
+ protected:
+ Status Init(int num_inputs, DataType input_type,
+ const string& template_ = "%s", const string& placeholder = "%s",
+ int summarize = 3) {
+ TF_CHECK_OK(NodeDefBuilder("op", "StringFormat")
+ .Input(FakeInput(num_inputs, input_type))
+ .Attr("template", template_)
+ .Attr("placeholder", placeholder)
+ .Attr("summarize", summarize)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(StringFormatGraphTest, Int32Success_7) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s"));
+
+ AddInputFromArray<int32>(TensorShape({7}), {1, 2, 3, 4, 5, 6, 7});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [1 2 3 ... 5 6 7]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(StringFormatGraphTest, Int32Success_3_3) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s", "%s", 1));
+
+ AddInputFromArray<int32>(TensorShape({3, 3}), {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [[1 ... 3]\n ..."
+ "\n [7 ... 9]]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+} // end namespace
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc
index ca296d5..2fbe1fe 100644
--- a/tensorflow/core/kernels/topk_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc
@@ -20,9 +20,9 @@
#include <cmath>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_segmented_radix_sort.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h
index 8879d9d..2255597 100644
--- a/tensorflow/core/kernels/where_op_gpu.cu.h
+++ b/tensorflow/core/kernels/where_op_gpu.cu.h
@@ -21,10 +21,10 @@
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_select.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_select.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
index 99684ae..9ccd911 100644
--- a/tensorflow/core/lib/core/threadpool.cc
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -17,6 +17,7 @@
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
@@ -120,6 +121,54 @@
impl_->Schedule(std::move(fn));
}
+int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
+    const int64 block_size, const int64 total) {
+  const bool run_inline = block_size <= 0 || total <= 1 ||
+                          total <= block_size || NumThreads() == 1;
+  if (run_inline) return 1;
+  // One shard per block_size-sized chunk, the last one possibly partial.
+  return (total + block_size - 1) / block_size;
+}
+
+// Like Eigen's parallelFor, but splits are block_size-aligned, so fn is called
+// exactly NumShardsUsedByTransformRangeConcurrently() times; waits for all.
+void ThreadPool::TransformRangeConcurrently(
+ const int64 block_size, const int64 total,
+ const std::function<void(int64, int64)>& fn) {
+ const int num_shards_used =
+ NumShardsUsedByTransformRangeConcurrently(block_size, total);
+ if (num_shards_used == 1) {
+ fn(0, total);
+ return;
+ }
+ // Adapted from Eigen's parallelFor. Closures capture handle_range and counter
+ // by reference; counter.Wait() keeps this frame alive until all shards are done.
+ BlockingCounter counter(num_shards_used);
+ std::function<void(int64, int64)> handle_range =
+ [=, &handle_range, &counter, &fn](int64 first, int64 last) {
+ while (last - first > block_size) {
+ // Split near the midpoint, rounded up to a multiple of block_size.
+ const int64 mid = first + ((last - first) / 2 + block_size - 1) /
+ block_size * block_size;
+ Schedule([=, &handle_range]() { handle_range(mid, last); });
+ last = mid;
+ }
+ // At most one block left: run it on this thread.
+ fn(first, last);
+ counter.DecrementCount(); // The shard is done.
+ };
+ if (num_shards_used <= NumThreads()) {
+ // Avoid a thread hop by running the root of the tree and one block on the
+ // calling thread.
+ handle_range(0, total);
+ } else {
+ // Execute the root in the thread pool to avoid running work on more than
+ // NumThreads() threads.
+ Schedule([=, &handle_range]() { handle_range(0, total); });
+ }
+ counter.Wait();
+}
+
void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
std::function<void(int64, int64)> fn) {
impl_->ParallelFor(total, cost_per_unit, std::move(fn));
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
index 74df7c8..e14ad7a 100644
--- a/tensorflow/core/lib/core/threadpool.h
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -59,6 +59,20 @@
// Schedules fn() for execution in the pool of threads.
void Schedule(std::function<void()> fn);
+ // Calls fn(i*block_size, min((i+1)*block_size, total)) once for each i in
+ // [0, k), k = NumShardsUsedByTransformRangeConcurrently(...), running the
+ // calls concurrently on the pool (one may run on the calling thread); calls
+ // beyond the pool's capacity are queued. Returns after every call finished.
+ // If block_size <= 0, total <= block_size, or the pool has a single thread,
+ // this is one inline call fn(0, total).
+ void TransformRangeConcurrently(const int64 block_size, const int64 total,
+ const std::function<void(int64, int64)>& fn);
+
+ // Returns how many times TransformRangeConcurrently calls fn for these
+ // parameters: 1 in the inline cases above, else ceil(total / block_size).
+ int NumShardsUsedByTransformRangeConcurrently(const int64 block_size,
+ const int64 total);
+
// ParallelFor shards the "total" units of work assuming each unit of work
// having roughly "cost_per_unit" cost, in cycles. Each unit of work is
// indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
index 320f3eb..db996b7 100644
--- a/tensorflow/core/lib/core/threadpool_test.cc
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -61,6 +61,67 @@
}
}
+// Shards [0, total) with TransformRangeConcurrently and checks that every
+// index is visited exactly once, in exactly ceil(total / block_size) shards.
+void RunSharding(int64 block_size, int64 total, ThreadPool* threads) {
+  mutex mu;
+  int64 num_shards = 0;
+  int64 num_done_work = 0;
+  std::vector<bool> work(total, false);
+  threads->TransformRangeConcurrently(
+      block_size, total,
+      [=, &mu, &num_shards, &num_done_work, &work](int64 start, int64 end) {
+        VLOG(1) << "Shard [" << start << "," << end << ")";
+        EXPECT_GE(start, 0);
+        EXPECT_LE(end, total);
+        mutex_lock l(mu);
+        ++num_shards;
+        for (; start < end; ++start) {
+          EXPECT_FALSE(work[start]);  // No duplicate
+          ++num_done_work;
+          work[start] = true;
+        }
+      });
+  VLOG(1) << "block_size=" << block_size << " total=" << total;
+  EXPECT_EQ(num_done_work, total);
+
+  // Splits happen on multiples of block_size, so there is exactly one shard
+  // per (possibly partial) block -- never more, never fewer.
+  const int64 num_workers = (total + block_size - 1) / block_size;
+  EXPECT_EQ(num_shards, num_workers);
+  EXPECT_EQ(num_shards, threads->NumShardsUsedByTransformRangeConcurrently(
+                            block_size, total));
+}
+
+// Adapted from work_sharder_test.cc
+TEST(ThreadPool, TransformRangeConcurrently) {
+  ThreadPool threads(Env::Default(), "test", 16);
+  for (auto block_size : {1, 7, 10, 64, 100, 256, 1000, 9999}) {
+    for (auto diff : {0, 1, 11, 102, 1003, 10005, 1000007}) {
+      const int64 total = block_size + diff;
+      RunSharding(block_size, total, &threads);
+    }
+  }
+}
+
+TEST(ThreadPool, NumShardsUsedByTransformRangeConcurrently) {
+  ThreadPool threads(Env::Default(), "test", 16);
+  EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
+                   3 /* block_size */, 3 /* total */));
+  EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+                   3 /* block_size */, 4 /* total */));
+  EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+                   3 /* block_size */, 5 /* total */));
+  EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+                   3 /* block_size */, 6 /* total */));
+  EXPECT_EQ(3, threads.NumShardsUsedByTransformRangeConcurrently(
+                   3 /* block_size */, 7 /* total */));
+  EXPECT_EQ(7, threads.NumShardsUsedByTransformRangeConcurrently(
+                   1 /* block_size */, 7 /* total */));
+  EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
+                   0 /* block_size */, 7 /* total */));
+}
+
TEST(ThreadPool, ParallelFor) {
Context outer_context(ContextKind::kThread);
// Make ParallelFor use as many threads as possible.
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 7dbb18a..c249506 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2916,6 +2916,34 @@
} // namespace
+REGISTER_OP("UpperBound")
+ .Input("sorted_inputs: T")
+ .Input("values: T")
+ .Output("output: out_type")
+ .Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
+ c->set_output(0, c->input(1));
+ return Status::OK();
+ });
+
+REGISTER_OP("LowerBound")
+ .Input("sorted_inputs: T")
+ .Input("values: T")
+ .Output("output: out_type")
+ .Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
+ c->set_output(0, c->input(1));
+ return Status::OK();
+ });
+
REGISTER_OP("ScatterNd")
.Input("indices: Tindices")
.Input("updates: T")
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index e599587..e30a111 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -29388,6 +29388,38 @@
}
}
op {
+ name: "LowerBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "MakeIterator"
input_arg {
name: "dataset"
@@ -38880,6 +38912,30 @@
is_stateful: true
}
op {
+ name: "PrintV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ attr {
+ name: "output_stream"
+ type: "string"
+ default_value {
+ s: "stderr"
+ }
+ allowed_values {
+ list {
+ s: "stdout"
+ s: "stderr"
+ s: "log(info)"
+ s: "log(warning)"
+ s: "log(error)"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "PriorityQueue"
output_arg {
name: "handle"
@@ -70188,6 +70244,43 @@
}
}
op {
+ name: "StringFormat"
+ input_arg {
+ name: "inputs"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "template"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "placeholder"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "summarize"
+ type: "int"
+ default_value {
+ i: 3
+ }
+ }
+}
+op {
name: "StringJoin"
input_arg {
name: "inputs"
@@ -75267,6 +75360,38 @@
is_stateful: true
}
op {
+ name: "UpperBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "VarHandleOp"
output_arg {
name: "resource"
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc
index f78f7a8..f84142c 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops.cc
@@ -37,7 +37,6 @@
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-
REGISTER_OP("CudnnRNNParamsSize")
.Input("num_layers: int32")
.Input("num_units: int32")
@@ -52,11 +51,16 @@
.Attr("seed2: int = 0")
.Output("params_size: S")
.SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ // num_layers, num_units, and input_size should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+
c->set_output(0, c->Vector(1));
return Status::OK();
});
-
REGISTER_OP("CudnnRNN")
.Input("input: T")
.Input("input_h: T")
@@ -248,7 +252,6 @@
return Status::OK();
});
-
REGISTER_OP("CudnnRNNCanonicalToParams")
.Input("num_layers: int32")
.Input("num_units: int32")
diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
index 2dd8675..13c3b93 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
@@ -26,7 +26,16 @@
TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
ShapeInferenceTestOp op("CudnnRNNParamsSize");
- INFER_OK(op, "[1];[1];[1]", "[1]");
+ INFER_OK(op, "[];[];[]", "[1]");
+ INFER_OK(op, "?;[];[]", "[1]");
+ INFER_OK(op, "[];?;[]", "[1]");
+ INFER_OK(op, "[];[];?", "[1]");
+ INFER_OK(op, "[];?;?", "[1]");
+ INFER_OK(op, "?;?;?", "[1]");
+
+ INFER_ERROR("Shape must be rank 0 ", op, "[1,2];?;[]");
+ INFER_ERROR("Shape must be rank 0 ", op, "?;[2];[]");
+ INFER_ERROR("Shape must be rank 0 ", op, "?;?;[1]");
}
TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index 639d211..2034d36 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -20,6 +20,8 @@
namespace tensorflow {
+using shape_inference::InferenceContext;
+
REGISTER_OP("Assert")
.Input("condition: bool")
.Input("data: T")
@@ -44,6 +46,23 @@
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Print");
+REGISTER_OP("PrintV2")
+    .Input("input: string")
+    .SetIsStateful()
+    .Attr(
+        "output_stream: {'stdout', 'stderr', 'log(info)', "
+        "'log(warning)', 'log(error)'} = 'stderr'")
+    .SetShapeFn([](InferenceContext* c) {
+      // Input must be a scalar; an input of unknown rank is let through.
+      if (c->RankKnown(c->input(0)) && c->Rank(c->input(0)) != 0) {
+        return errors::InvalidArgument("input must be a scalar, but has rank: ",
+                                       c->Rank(c->input(0)));
+      }
+      return Status::OK();
+    });
+
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("PrintV2");
+
// ----------------------------------------------------------------------------
// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
// inputs or outputs in various ways.
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 4ece1c8..594edfd 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -14536,6 +14536,38 @@
}
}
op {
+ name: "LowerBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "MakeIterator"
input_arg {
name: "dataset"
@@ -19521,6 +19553,30 @@
is_stateful: true
}
op {
+ name: "PrintV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ attr {
+ name: "output_stream"
+ type: "string"
+ default_value {
+ s: "stderr"
+ }
+ allowed_values {
+ list {
+ s: "stdout"
+ s: "stderr"
+ s: "log(info)"
+ s: "log(warning)"
+ s: "log(error)"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "PriorityQueue"
output_arg {
name: "handle"
@@ -32735,6 +32791,43 @@
}
}
op {
+ name: "StringFormat"
+ input_arg {
+ name: "inputs"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "template"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "placeholder"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "summarize"
+ type: "int"
+ default_value {
+ i: 3
+ }
+ }
+}
+op {
name: "StringJoin"
input_arg {
name: "inputs"
@@ -35954,6 +36047,38 @@
is_stateful: true
}
op {
+ name: "UpperBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "VarHandleOp"
output_arg {
name: "resource"
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index ef8b15d..9915983 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -13,6 +13,7 @@
limitations under the License.
==============================================================================*/
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -102,6 +103,32 @@
.Attr("fill: string = ''")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("StringFormat")
+    .Input("inputs: T")
+    .Output("output: string")
+    .Attr("T: list(type) >= 0")
+    .Attr("template: string = '%s'")
+    .Attr("placeholder: string = '%s'")
+    .Attr("summarize: int = 3")
+    .SetShapeFn([](InferenceContext* c) {
+      string fmt;
+      string marker;
+      TF_RETURN_IF_ERROR(c->GetAttr("template", &fmt));
+      TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &marker));
+
+      // N placeholders split the template into N + 1 literal pieces.
+      const std::vector<std::string> pieces = absl::StrSplit(fmt, marker);
+      const int64 num_placeholders = pieces.size() - 1;
+      if (num_placeholders != c->num_inputs()) {
+        return errors::InvalidArgument(strings::StrCat(
+            "num placeholders in template and num inputs must match: ",
+            num_placeholders, " vs. ", c->num_inputs()));
+      }
+
+      c->set_output(0, c->Scalar());
+      return Status::OK();
+    });
+
REGISTER_OP("StringJoin")
.Input("inputs: N * string")
.Attr("N: int")
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 83228fa..83ea853 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -25,6 +25,7 @@
#ifdef _WIN32
#include <io.h> // for _mktemp
#endif
+#include "absl/base/macros.h"
#include "include/json/json.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -63,7 +64,7 @@
// The HTTP response code "308 Resume Incomplete".
constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308;
// The environment variable that overrides the size of the readahead buffer.
-// DEPRECATED. Use GCS_BLOCK_SIZE_MB instead.
+ABSL_DEPRECATED("Use GCS_BLOCK_SIZE_MB instead.")
constexpr char kReadaheadBufferSize[] = "GCS_READAHEAD_BUFFER_SIZE_BYTES";
// The environment variable that disables the GCS block cache for reads.
// This is the explicit alternative to setting BLOCK_SIZE or MAX_SIZE to 0, and
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
index 3a012c2..37475fe 100644
--- a/tensorflow/core/platform/default/build_config_root.bzl
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -3,64 +3,64 @@
# be separate to avoid cyclic references.
def tf_cuda_tests_tags():
- return ["requires-gpu"]
+ return ["requires-gpu", "local", "gpu"]
def tf_sycl_tests_tags():
- return ["requires-gpu"]
+ return ["requires-gpu", "local", "gpu"]
def tf_additional_plugin_deps():
- return select({
- str(Label("//tensorflow:with_xla_support")): [
- str(Label("//tensorflow/compiler/jit"))
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_xla_support")): [
+ str(Label("//tensorflow/compiler/jit")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_xla_deps_py():
- return []
+ return []
def tf_additional_grpc_deps_py():
- return []
+ return []
def tf_additional_license_deps():
- return select({
- str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
+ "//conditions:default": [],
+ })
def tf_additional_verbs_deps():
- return select({
- str(Label("//tensorflow:with_verbs_support")): [
- str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
- str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_verbs_support")): [
+ str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
+ str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_mpi_deps():
- return select({
- str(Label("//tensorflow:with_mpi_support")): [
- str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_mpi_support")): [
+ str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_gdr_deps():
- return select({
- str(Label("//tensorflow:with_gdr_support")): [
- str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_gdr_support")): [
+ str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
+ ],
+ "//conditions:default": [],
+ })
-def if_static(extra_deps, otherwise=[]):
- return select({
- str(Label("//tensorflow:framework_shared_object")): otherwise,
- "//conditions:default": extra_deps,
- })
+def if_static(extra_deps, otherwise = []):
+ return select({
+ str(Label("//tensorflow:framework_shared_object")): otherwise,
+ "//conditions:default": extra_deps,
+ })
-def if_dynamic_kernels(extra_deps, otherwise=[]):
- return select({
- str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
- "//conditions:default": otherwise,
- })
+def if_dynamic_kernels(extra_deps, otherwise = []):
+ return select({
+ str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
+ "//conditions:default": otherwise,
+ })
diff --git a/tensorflow/core/platform/default/cord.h b/tensorflow/core/platform/default/cord.h
index 1ab6821..5823374 100644
--- a/tensorflow/core/platform/default/cord.h
+++ b/tensorflow/core/platform/default/cord.h
@@ -16,9 +16,6 @@
#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
-class Cord;
-namespace absl {
-using ::Cord;
-} // namespace absl
+// TODO(ebrevdo): Fill this in.
#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 30059dc..156af6c 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -255,10 +255,13 @@
/// \brief Append 'data' to the file.
virtual Status Append(StringPiece data) = 0;
+ // TODO(ebrevdo): Remove this ifdef when absl is updated.
+#if defined(PLATFORM_GOOGLE)
// \brief Append 'data' to the file.
virtual Status Append(const absl::Cord& cord) {
return errors::Unimplemented("Append(absl::Cord) is not implemented");
}
+#endif
/// \brief Close the file.
///
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 625d564..85cd023 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -68,7 +68,7 @@
// after the process starts. Users are required to use vendor
// specific mechanisms (e.g., CUDA_VISIBLE_DEVICES) to control the
// physical to visible device mapping prior to invoking TensorFlow.
- // 2. In the code, the ids in this list are also called "CUDA GPU id"s,
+ // 2. In the code, the ids in this list are also called "platform GPU id"s,
// and the 'virtual' ids of GPU devices (i.e. the ids in the device
// name "/device:GPU:<id>") are also called "TF GPU id"s. Please
// refer to third_party/tensorflow/core/common_runtime/gpu/gpu_id.h
diff --git a/tensorflow/core/protobuf/replay_log.proto b/tensorflow/core/protobuf/replay_log.proto
new file mode 100644
index 0000000..7644314
--- /dev/null
+++ b/tensorflow/core/protobuf/replay_log.proto
@@ -0,0 +1,47 @@
+syntax = "proto3";
+
+option cc_enable_arenas = true;
+package tensorflow;
+
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/protobuf/cluster.proto";
+import "tensorflow/core/protobuf/master.proto";
+
+// Records the creation of a new replay session. We record the device listing
+// here to capture the state of the cluster.
+message NewReplaySession {
+ ListDevicesResponse devices = 1;
+ string session_handle = 2;
+}
+
+message ReplayOp {
+ double start_time_us = 31;
+ double end_time_us = 32;
+
+ oneof op {
+ CreateSessionRequest create_session = 1;
+ ExtendSessionRequest extend_session = 2;
+ PartialRunSetupRequest partial_run_setup = 3;
+ RunStepRequest run_step = 4;
+ CloseSessionRequest close_session = 5;
+ ListDevicesRequest list_devices = 6;
+ ResetRequest reset_request = 7;
+ MakeCallableRequest make_callable = 8;
+ RunCallableRequest run_callable = 9;
+ ReleaseCallableRequest release_callable = 10;
+ NewReplaySession new_replay_session = 11;
+ }
+
+ oneof response {
+ CreateSessionResponse create_session_response = 21;
+ ExtendSessionResponse extend_session_response = 22;
+ PartialRunSetupResponse partial_run_setup_response = 23;
+ RunStepResponse run_step_response = 24;
+ CloseSessionResponse close_session_response = 25;
+ ListDevicesResponse list_devices_response = 26;
+ ResetResponse reset_request_response = 27;
+ MakeCallableResponse make_callable_response = 28;
+ RunCallableResponse run_callable_response = 29;
+ ReleaseCallableResponse release_callable_response = 30;
+ }
+}
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 540adb5..f6f0408 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -93,11 +93,11 @@
}
namespace cuda_helper {
-template <typename IntType>
-__device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
- IntType* orig = first;
- IntType* it = nullptr;
- IntType step = 0;
+template <typename T, typename OutType = int32>
+__device__ OutType upper_bound(const T* first, OutType count, T val) {
+ const T* orig = first;
+ const T* it = nullptr;
+ OutType step = 0;
while (count > 0) {
it = first;
step = count / 2;
@@ -112,6 +112,27 @@
return first - orig;
}
+
+template <typename T, typename OutType = int32>
+__device__ OutType lower_bound(const T* first, OutType count, T val) {
+ const T* orig = first;
+ const T* it = nullptr;
+ OutType step = 0;
+ while (count > 0) {
+ it = first;
+ step = count / 2;
+ it += step;
+ if (*it < val) {
+ first = ++it;
+ count -= step + 1;
+ } else {
+ count = step;
+ }
+ }
+
+ return first - orig;
+}
+
} // namespace cuda_helper
} // namespace tensorflow
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index 0f04b65..b9ca8ab 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -20,6 +20,7 @@
#include <numeric>
#include <vector>
+#include "absl/base/macros.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -95,21 +96,21 @@
SparseTensor() : dims_(0) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
: SparseTensor(ix, vals, TensorShapeToVector(shape),
UndefinedOrder(TensorShapeToVector(shape))) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
: SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
const VarDimArray order)
: SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
const VarDimArray order)
: ix_(ix),
@@ -237,9 +238,10 @@
static Status Split(const SparseTensor& tensor, const int split_dim,
const int num_split, std::vector<SparseTensor>* result);
- // DEPRECATED: use the form of Split() that takes an output pointer and
- // returns a status instead.
template <typename T>
+ ABSL_DEPRECATED(
+ "Use the form of Split() that takes an output pointer and returns a "
+ "status instead.")
static std::vector<SparseTensor> Split(const SparseTensor& tensor,
const int split_dim,
const int num_split,
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
index f4bd295..74f0713 100644
--- a/tensorflow/core/util/work_sharder.cc
+++ b/tensorflow/core/util/work_sharder.cc
@@ -50,6 +50,8 @@
max_parallelism);
}
+// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
+// to directly specify the shard size.
void Sharder::Do(int64 total, int64 cost_per_unit, const Work& work,
const Runner& runner, int max_parallelism) {
cost_per_unit = std::max(int64{1}, cost_per_unit);
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
index b12c31c..9db85a5 100644
--- a/tensorflow/core/util/work_sharder.h
+++ b/tensorflow/core/util/work_sharder.h
@@ -23,6 +23,9 @@
namespace tensorflow {
+// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
+// to directly specify the shard size. Use this function only if you want to
+// manually cap parallelism.
// Shards the "total" unit of work assuming each unit of work having
// roughly "cost_per_unit". Each unit of work is indexed 0, 1, ...,
// total - 1. Each shard contains 1 or more units of work and the
diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD
index d4070fd..99da44d 100644
--- a/tensorflow/examples/tutorials/mnist/BUILD
+++ b/tensorflow/examples/tutorials/mnist/BUILD
@@ -84,6 +84,18 @@
)
py_binary(
+ name = "mnist_softmax_xla",
+ srcs = [
+ "mnist_softmax_xla.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":input_data",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
name = "mnist_deep",
srcs = [
"mnist_deep.py",
diff --git a/tensorflow/go/README.md b/tensorflow/go/README.md
index 288a325..3989f9b 100644
--- a/tensorflow/go/README.md
+++ b/tensorflow/go/README.md
@@ -10,7 +10,7 @@
## Quickstart
-Refer to [Installing TensorFlow for Go](https://www.tensorflow.org/install/install_go)
+Refer to [Installing TensorFlow for Go](https://www.tensorflow.org/install/lang_go)
## Building the TensorFlow C library from source
@@ -23,9 +23,7 @@
- [bazel](https://www.bazel.build/versions/master/docs/install.html)
- Environment to build TensorFlow from source code
- ([Linux](https://www.tensorflow.org/install/install_sources#PrepareLinux)
- or [OS
- X](https://www.tensorflow.org/install/install_sources#PrepareMac)).
+ ([Linux or macOS](https://www.tensorflow.org/install/source)).
If you don't need GPU support, then try the following:
```sh
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index eb636db..1d72bcd 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3741,98 +3741,28 @@
return op.Output(0)
}
-// Computes the sum along sparse segments of a tensor.
+// Makes the summary of accumulated stats for the batch.
//
-// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
-// misisng, the `output` tensor at that position will be zeroed.
-//
-// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
-// for an explanation of segments.
-//
-// For example:
-//
-// ```python
-// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
-//
-// tf.sparse_segment_sum_with_num_segments(
-// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
-// # => [[0 0 0 0]
-// # [0 0 0 0]
-// # [0 0 0 0]]
-//
-// tf.sparse_segment_sum_with_num_segments(c,
-// tf.constant([0, 1]),
-// tf.constant([0, 2],
-// num_segments=4))
-// # => [[ 1 2 3 4]
-// # [ 0 0 0 0]
-// # [-1 -2 -3 -4]
-// # [ 0 0 0 0]]
-// ```
+// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example.
//
// Arguments:
+// node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer.
+// gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients.
+// hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians.
+// bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column).
+// max_splits: int; the maximum number of splits possible in the whole tree.
+// num_buckets: int; equals to the maximum possible value of bucketized feature.
//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
-// num_segments: Should equal the number of distinct segment IDs.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
-func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians.
+func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets}
opspec := tf.OpSpec{
- Type: "SparseSegmentSumWithNumSegments",
+ Type: "BoostedTreesMakeStatsSummary",
Input: []tf.Input{
- data, indices, segment_ids, num_segments,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// PreventGradientAttr is an optional argument to PreventGradient.
-type PreventGradientAttr func(optionalAttr)
-
-// PreventGradientMessage sets the optional message attribute to value.
-//
-// value: Will be printed in the error when anyone tries to differentiate
-// this operation.
-// If not specified, defaults to ""
-func PreventGradientMessage(value string) PreventGradientAttr {
- return func(m optionalAttr) {
- m["message"] = value
- }
-}
-
-// An identity op that triggers an error if a gradient is requested.
-//
-// When executed in a graph, this op outputs its input tensor as-is.
-//
-// When building ops to compute gradients, the TensorFlow gradient system
-// will return an error when trying to lookup the gradient of this op,
-// because no gradient must ever be registered for this function. This
-// op exists to prevent subtle bugs from silently returning unimplemented
-// gradients in some corner cases.
-//
-// Arguments:
-// input: any tensor.
-//
-// Returns the same input tensor.
-func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "PreventGradient",
- Input: []tf.Input{
- input,
+ node_ids, gradients, hessians, tf.OutputList(bucketized_features_list),
},
Attrs: attrs,
}
@@ -3840,21 +3770,6 @@
return op.Output(0)
}
-// Computes asin of x element-wise.
-func Asin(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Asin",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes the sum along sparse segments of a tensor.
//
// Read
@@ -4564,104 +4479,6 @@
return op.Output(0)
}
-// NthElementAttr is an optional argument to NthElement.
-type NthElementAttr func(optionalAttr)
-
-// NthElementReverse sets the optional reverse attribute to value.
-//
-// value: When set to True, find the nth-largest value in the vector and vice
-// versa.
-// If not specified, defaults to false
-func NthElementReverse(value bool) NthElementAttr {
- return func(m optionalAttr) {
- m["reverse"] = value
- }
-}
-
-// Finds values of the `n`-th order statistic for the last dimension.
-//
-// If the input is a vector (rank-1), finds the entries which is the nth-smallest
-// value in the vector and outputs their values as scalar tensor.
-//
-// For matrices (resp. higher rank input), computes the entries which is the
-// nth-smallest value in each row (resp. vector along the last dimension). Thus,
-//
-// values.shape = input.shape[:-1]
-//
-// Arguments:
-// input: 1-D or higher with last dimension at least `n+1`.
-// n: 0-D. Position of sorted vector to select along the last dimension (along
-// each row for matrices). Valid range of n is `[0, input.shape[:-1])`
-//
-// Returns The `n`-th order statistic along each last dimensional slice.
-func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "NthElement",
- Input: []tf.Input{
- input, n,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes the maximum along segments of a tensor.
-//
-// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
-// for an explanation of segments.
-//
-// This operator is similar to the unsorted segment sum operator found
-// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
-// Instead of computing the sum over segments, it computes the maximum such that:
-//
-// \\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
-// that `segment_ids[j...] == i`.
-//
-// If the maximum is empty for a given segment ID `i`, it outputs the smallest
-// possible value for the specific numeric type,
-// `output[i] = numeric_limits<T>::lowest()`.
-//
-// If the given segment ID `i` is negative, then the corresponding value is
-// dropped, and will not be included in the result.
-//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
-// </div>
-//
-// Arguments:
-//
-// segment_ids: A tensor whose shape is a prefix of `data.shape`.END
-// }
-// out_arg {
-// name: "output"
-// description: <<END
-// Has same shape as data, except for the first `segment_ids.rank`
-// dimensions, which are replaced with a single dimension which has size
-// `num_segments`.
-//
-func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "UnsortedSegmentMax",
- Input: []tf.Input{
- data, segment_ids, num_segments,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes exponential of x element-wise. \\(y = e^x\\).
func Exp(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
@@ -4769,6 +4586,218 @@
return op.Output(0), op.Output(1), op.Output(2)
}
+// PreventGradientAttr is an optional argument to PreventGradient.
+type PreventGradientAttr func(optionalAttr)
+
+// PreventGradientMessage sets the optional message attribute to value.
+//
+// value: Will be printed in the error when anyone tries to differentiate
+// this operation.
+// If not specified, defaults to ""
+func PreventGradientMessage(value string) PreventGradientAttr {
+ return func(m optionalAttr) {
+ m["message"] = value
+ }
+}
+
+// An identity op that triggers an error if a gradient is requested.
+//
+// When executed in a graph, this op outputs its input tensor as-is.
+//
+// When building ops to compute gradients, the TensorFlow gradient system
+// will return an error when trying to lookup the gradient of this op,
+// because no gradient must ever be registered for this function. This
+// op exists to prevent subtle bugs from silently returning unimplemented
+// gradients in some corner cases.
+//
+// Arguments:
+// input: any tensor.
+//
+// Returns the same input tensor.
+func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "PreventGradient",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes asin of x element-wise.
+func Asin(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Asin",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the maximum along segments of a tensor.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// This operator is similar to the unsorted segment sum operator found
+// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
+// Instead of computing the sum over segments, it computes the maximum such that:
+//
+// \\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
+// that `segment_ids[j...] == i`.
+//
+// If the maximum is empty for a given segment ID `i`, it outputs the smallest
+// possible value for the specific numeric type,
+// `output[i] = numeric_limits<T>::lowest()`.
+//
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
+// </div>
+//
+// Arguments:
+//
+// segment_ids: A tensor whose shape is a prefix of `data.shape`.END
+// }
+// out_arg {
+// name: "output"
+// description: <<END
+// Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
+//
+func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "UnsortedSegmentMax",
+ Input: []tf.Input{
+ data, segment_ids, num_segments,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// NthElementAttr is an optional argument to NthElement.
+type NthElementAttr func(optionalAttr)
+
+// NthElementReverse sets the optional reverse attribute to value.
+//
+// value: When set to True, find the nth-largest value in the vector and vice
+// versa.
+// If not specified, defaults to false
+func NthElementReverse(value bool) NthElementAttr {
+ return func(m optionalAttr) {
+ m["reverse"] = value
+ }
+}
+
+// Finds values of the `n`-th order statistic for the last dimension.
+//
+// If the input is a vector (rank-1), finds the entries which is the nth-smallest
+// value in the vector and outputs their values as scalar tensor.
+//
+// For matrices (resp. higher rank input), computes the entries which is the
+// nth-smallest value in each row (resp. vector along the last dimension). Thus,
+//
+// values.shape = input.shape[:-1]
+//
+// Arguments:
+// input: 1-D or higher with last dimension at least `n+1`.
+// n: 0-D. Position of sorted vector to select along the last dimension (along
+// each row for matrices). Valid range of n is `[0, input.shape[:-1])`
+//
+// Returns The `n`-th order statistic along each last dimensional slice.
+func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "NthElement",
+ Input: []tf.Input{
+ input, n,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the sum along sparse segments of a tensor.
+//
+// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
+// missing, the `output` tensor at that position will be zeroed.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// For example:
+//
+// ```python
+// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+//
+// tf.sparse_segment_sum_with_num_segments(
+// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
+// # => [[0 0 0 0]
+// # [0 0 0 0]
+// # [0 0 0 0]]
+//
+// tf.sparse_segment_sum_with_num_segments(c,
+// tf.constant([0, 1]),
+// tf.constant([0, 2],
+// num_segments=4))
+// # => [[ 1 2 3 4]
+// # [ 0 0 0 0]
+// # [-1 -2 -3 -4]
+// # [ 0 0 0 0]]
+// ```
+//
+// Arguments:
+//
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+// num_segments: Should equal the number of distinct segment IDs.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `num_segments`.
+func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentSumWithNumSegments",
+ Input: []tf.Input{
+ data, indices, segment_ids, num_segments,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the determinant of one or more square matrices.
//
// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
@@ -9229,6 +9258,66 @@
return op.Output(0)
}
+// RandomUniformIntAttr is an optional argument to RandomUniformInt.
+type RandomUniformIntAttr func(optionalAttr)
+
+// RandomUniformIntSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomUniformIntSeed(value int64) RandomUniformIntAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomUniformIntSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomUniformIntSeed2(value int64) RandomUniformIntAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Outputs random integers from a uniform distribution.
+//
+// The generated values are uniform integers in the range `[minval, maxval)`.
+// The lower bound `minval` is included in the range, while the upper bound
+// `maxval` is excluded.
+//
+// The random integers are slightly biased unless `maxval - minval` is an exact
+// power of two. The bias is small for values of `maxval - minval` significantly
+// smaller than the range of the output (either `2^32` or `2^64`).
+//
+// Arguments:
+// shape: The shape of the output tensor.
+// minval: 0-D. Inclusive lower bound on the generated integers.
+// maxval: 0-D. Exclusive upper bound on the generated integers.
+//
+// Returns A tensor of the specified shape filled with uniform random integers.
+func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomUniformInt",
+ Input: []tf.Input{
+ shape, minval, maxval,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl.
type ResourceApplyFtrlAttr func(optionalAttr)
@@ -11926,38 +12015,6 @@
return op.Output(0)
}
-// The gradient operator for the SparseAdd op.
-//
-// The SparseAdd op calculates A + B, where A, B, and the sum are all represented
-// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t.
-// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty
-// values of A and B.
-//
-// Arguments:
-// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to
-// the non-empty values of the sum.
-// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`.
-// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`.
-// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size
-// `[nnz(sum), ndims]`.
-//
-// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the
-// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the
-// non-empty values of B.
-func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseAddGrad",
- Input: []tf.Input{
- backprop_val_grad, a_indices, b_indices, sum_indices,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
// String lengths of `input`.
//
// Computes the length of each string given in the input tensor.
@@ -12814,6 +12871,123 @@
return op.Output(0)
}
+// ShapeAttr is an optional argument to Shape.
+type ShapeAttr func(optionalAttr)
+
+// ShapeOutType sets the optional out_type attribute to value.
+// If not specified, defaults to DT_INT32
+func ShapeOutType(value tf.DataType) ShapeAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Returns the shape of a tensor.
+//
+// This operation returns a 1-D integer tensor representing the shape of `input`.
+//
+// For example:
+//
+// ```
+// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+// shape(t) ==> [2, 2, 3]
+// ```
+func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Shape",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the power of one value to another.
+//
+// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
+// corresponding elements in `x` and `y`. For example:
+//
+// ```
+// # tensor 'x' is [[2, 2]], [3, 3]]
+// # tensor 'y' is [[8, 16], [2, 3]]
+// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
+// ```
+func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Pow",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes fingerprints of the input strings.
+//
+// Arguments:
+// input: vector of strings to compute fingerprints on.
+//
+// Returns a (N,2) shaped matrix where N is the number of elements in the input
+// vector. Each row contains the low and high parts of the fingerprint.
+func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SdcaFprint",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// The gradient operator for the SparseAdd op.
+//
+// The SparseAdd op calculates A + B, where A, B, and the sum are all represented
+// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t.
+// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty
+// values of A and B.
+//
+// Arguments:
+// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to
+// the non-empty values of the sum.
+// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`.
+// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`.
+// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size
+// `[nnz(sum), ndims]`.
+//
+// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the
+// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the
+// non-empty values of B.
+func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseAddGrad",
+ Input: []tf.Input{
+ backprop_val_grad, a_indices, b_indices, sum_indices,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
// Computes the mean along segments of a tensor.
//
// Read
@@ -13006,6 +13180,79 @@
return op.Output(0)
}
+// RandomPoissonV2Attr is an optional argument to RandomPoissonV2.
+type RandomPoissonV2Attr func(optionalAttr)
+
+// RandomPoissonV2Seed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomPoissonV2Seed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// RandomPoissonV2Dtype sets the optional dtype attribute to value.
+// If not specified, defaults to DT_INT64
+func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr {
+ return func(m optionalAttr) {
+ m["dtype"] = value
+ }
+}
+
+// Outputs random values from the Poisson distribution(s) described by rate.
+//
+// This op uses two algorithms, depending on rate. If rate >= 10, then
+// the algorithm by Hormann is used to acquire samples via
+// transformation-rejection.
+// See http://www.sciencedirect.com/science/article/pii/0167668793909974.
+//
+// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
+// random variables.
+// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
+// Programming, Volume 2. Addison Wesley
+//
+// Arguments:
+// shape: 1-D integer tensor. Shape of independent samples to draw from each
+// distribution described by the shape parameters given in rate.
+// rate: A tensor in which each scalar is a "rate" parameter describing the
+// associated poisson distribution.
+//
+// Returns A tensor with shape `shape + shape(rate)`. Each slice
+// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
+// `rate[i0, i1, ...iN]`.
+func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomPoissonV2",
+ Input: []tf.Input{
+ shape, rate,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg.
type DecodeAndCropJpegAttr func(optionalAttr)
@@ -20288,164 +20535,6 @@
return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights
}
-// ShapeAttr is an optional argument to Shape.
-type ShapeAttr func(optionalAttr)
-
-// ShapeOutType sets the optional out_type attribute to value.
-// If not specified, defaults to DT_INT32
-func ShapeOutType(value tf.DataType) ShapeAttr {
- return func(m optionalAttr) {
- m["out_type"] = value
- }
-}
-
-// Returns the shape of a tensor.
-//
-// This operation returns a 1-D integer tensor representing the shape of `input`.
-//
-// For example:
-//
-// ```
-// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
-// shape(t) ==> [2, 2, 3]
-// ```
-func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Shape",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes the power of one value to another.
-//
-// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
-// corresponding elements in `x` and `y`. For example:
-//
-// ```
-// # tensor 'x' is [[2, 2]], [3, 3]]
-// # tensor 'y' is [[8, 16], [2, 3]]
-// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
-// ```
-func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Pow",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes fingerprints of the input strings.
-//
-// Arguments:
-// input: vector of strings to compute fingerprints on.
-//
-// Returns a (N,2) shaped matrix where N is the number of elements in the input
-// vector. Each row contains the low and high parts of the fingerprint.
-func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SdcaFprint",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// RandomPoissonV2Attr is an optional argument to RandomPoissonV2.
-type RandomPoissonV2Attr func(optionalAttr)
-
-// RandomPoissonV2Seed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomPoissonV2Seed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// RandomPoissonV2Dtype sets the optional dtype attribute to value.
-// If not specified, defaults to DT_INT64
-func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr {
- return func(m optionalAttr) {
- m["dtype"] = value
- }
-}
-
-// Outputs random values from the Poisson distribution(s) described by rate.
-//
-// This op uses two algorithms, depending on rate. If rate >= 10, then
-// the algorithm by Hormann is used to acquire samples via
-// transformation-rejection.
-// See http://www.sciencedirect.com/science/article/pii/0167668793909974.
-//
-// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
-// random variables.
-// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
-// Programming, Volume 2. Addison Wesley
-//
-// Arguments:
-// shape: 1-D integer tensor. Shape of independent samples to draw from each
-// distribution described by the shape parameters given in rate.
-// rate: A tensor in which each scalar is a "rate" parameter describing the
-// associated poisson distribution.
-//
-// Returns A tensor with shape `shape + shape(rate)`. Each slice
-// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
-// `rate[i0, i1, ...iN]`.
-func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomPoissonV2",
- Input: []tf.Input{
- shape, rate,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve.
type MatrixTriangularSolveAttr func(optionalAttr)
@@ -20959,66 +21048,6 @@
return op.Output(0)
}
-// RandomUniformIntAttr is an optional argument to RandomUniformInt.
-type RandomUniformIntAttr func(optionalAttr)
-
-// RandomUniformIntSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomUniformIntSeed(value int64) RandomUniformIntAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomUniformIntSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomUniformIntSeed2(value int64) RandomUniformIntAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Outputs random integers from a uniform distribution.
-//
-// The generated values are uniform integers in the range `[minval, maxval)`.
-// The lower bound `minval` is included in the range, while the upper bound
-// `maxval` is excluded.
-//
-// The random integers are slightly biased unless `maxval - minval` is an exact
-// power of two. The bias is small for values of `maxval - minval` significantly
-// smaller than the range of the output (either `2^32` or `2^64`).
-//
-// Arguments:
-// shape: The shape of the output tensor.
-// minval: 0-D. Inclusive lower bound on the generated integers.
-// maxval: 0-D. Exclusive upper bound on the generated integers.
-//
-// Returns A tensor of the specified shape filled with uniform random integers.
-func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomUniformInt",
- Input: []tf.Input{
- shape, minval, maxval,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes the mean along sparse segments of a tensor.
//
// Read
@@ -28116,35 +28145,6 @@
return scope.AddOperation(opspec)
}
-// Makes the summary of accumulated stats for the batch.
-//
-// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example.
-//
-// Arguments:
-// node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer.
-// gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients.
-// hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians.
-// bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column).
-// max_splits: int; the maximum number of splits possible in the whole tree.
-// num_buckets: int; equals to the maximum possible value of bucketized feature.
-//
-// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians.
-func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets}
- opspec := tf.OpSpec{
- Type: "BoostedTreesMakeStatsSummary",
- Input: []tf.Input{
- node_ids, gradients, hessians, tf.OutputList(bucketized_features_list),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Adjust the contrast of one or more images.
//
// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are
diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md
index c7382ff..7ef862a 100644
--- a/tensorflow/java/README.md
+++ b/tensorflow/java/README.md
@@ -10,7 +10,7 @@
## Quickstart
-- Refer to [Installing TensorFlow for Java](https://www.tensorflow.org/install/install_java)
+- Refer to [Installing TensorFlow for Java](https://www.tensorflow.org/install/lang_java)
- [Javadoc](https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/package-summary)
- [](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/tensorflow)
@@ -22,8 +22,7 @@
1. Install [bazel](https://www.bazel.build/versions/master/docs/install.html)
2. Setup the environment to build TensorFlow from source code
- ([Linux](https://www.tensorflow.org/install/install_sources#PrepareLinux)
- or [macOS](https://www.tensorflow.org/install/install_sources#PrepareMac)).
+ ([Linux or macOS](https://www.tensorflow.org/install/source)).
If you'd like to skip reading those details and do not care about GPU
support, try the following:
@@ -35,7 +34,7 @@
brew install swig
```
-3. [Configure](https://www.tensorflow.org/install/install_sources#configure_the_installation)
+3. [Configure](https://www.tensorflow.org/install/source)
(e.g., enable GPU support) and build:
```sh
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index cf6a64d..6c82301 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.11.0-rc0</version>
+ <version>1.11.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index 978c3cb..f763479 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.11.0-rc0</version>
+ <version>1.11.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index d1378b5..7fcc6ff 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.11.0-rc0</version>
+ <version>1.11.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index 1342b0e..689902e 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.11.0-rc0</version>
+ <version>1.11.0-rc1</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index 19ff65a..ea1462a 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.11.0-rc0</version>
+ <version>1.11.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
diff --git a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
index ba7e9f4..ce1ebfa 100644
--- a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
+++ b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
@@ -6,7 +6,7 @@
<groupId>org.tensorflow</groupId>
<artifactId>spark-tensorflow-connector_2.11</artifactId>
<packaging>jar</packaging>
- <version>1.11.0-rc0</version>
+ <version>1.11.0-rc1</version>
<name>spark-tensorflow-connector</name>
<url>https://www.tensorflow.org</url>
<description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
diff --git a/tensorflow/java/maven/tensorflow-hadoop/pom.xml b/tensorflow/java/maven/tensorflow-hadoop/pom.xml
index f913faf..56346fd 100644
--- a/tensorflow/java/maven/tensorflow-hadoop/pom.xml
+++ b/tensorflow/java/maven/tensorflow-hadoop/pom.xml
@@ -5,7 +5,7 @@
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-hadoop</artifactId>
<packaging>jar</packaging>
- <version>1.11.0-rc0</version>
+ <version>1.11.0-rc1</version>
<name>tensorflow-hadoop</name>
<url>https://www.tensorflow.org</url>
<description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index f6cb595..93decea 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.11.0-rc0</version>
+ <version>1.11.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 2eeae77..79f1446 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1998,6 +1998,29 @@
)
py_library(
+ name = "while_v2",
+ srcs = [
+ "ops/while_v2.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":cond_v2_impl",
+ ":constant_op",
+ ":control_flow_util",
+ ":framework_ops",
+ ":function_def_to_graph",
+ ":functional_ops_gen",
+ ":gradients_impl",
+ ":list_ops",
+ ":tensor_shape",
+ ":util",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/eager:function",
+ ],
+)
+
+py_library(
name = "cond_v2_impl",
srcs = [
"ops/cond_v2_impl.py",
@@ -2301,6 +2324,8 @@
deps = [
":framework_for_generated_wrappers",
":logging_ops_gen",
+ ":platform",
+ ":string_ops",
":util",
],
)
@@ -3738,6 +3763,19 @@
],
)
+cc_library(
+ name = "session_ref",
+ srcs = ["client/session_ref.cc"],
+ hdrs = ["client/session_ref.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:replay_log_proto_cc",
+ ],
+)
+
tf_cuda_library(
name = "tf_session_helper",
srcs = ["client/tf_session_helper.cc"],
@@ -3748,6 +3786,7 @@
":ndarray_tensor_bridge",
":numpy_lib",
":safe_ptr",
+ ":session_ref",
":test_ops_kernels",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
@@ -3760,7 +3799,6 @@
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core:session_ref",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],
diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
index 7b3905f..80928ae 100644
--- a/tensorflow/python/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -63,10 +63,8 @@
from __future__ import division
from __future__ import print_function
-import collections
from enum import Enum
-
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import naming
from tensorflow.python.autograph.pyct import anno
@@ -129,9 +127,8 @@
self.autograph_module = autograph_module
self.uncompiled_modules = uncompiled_modules
- # Required to output dependencies in discovery order, which should match
- # the reverse dependency order.
- self.dependency_cache = collections.OrderedDict()
+ self.conversion_order = []
+ self.dependency_cache = {}
self.additional_imports = set()
self.name_map = {}
@@ -177,6 +174,7 @@
self.name_map[o] = name
def add_to_cache(self, original_entity, converted_ast):
+ self.conversion_order.append(original_entity)
self.dependency_cache[original_entity] = converted_ast
diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index ee2467e..1dc97d2 100644
--- a/tensorflow/python/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -302,8 +302,9 @@
arg_types)
nodes = []
- for dep in reversed(tuple(program_ctx.dependency_cache.values())):
- nodes.extend(dep)
+ for dep in reversed(program_ctx.conversion_order):
+ nodes.extend(program_ctx.dependency_cache[dep])
+
compiled_module, compiled_src = compiler.ast_to_object(
nodes,
source_prefix=program_ctx.required_imports,
@@ -371,7 +372,7 @@
conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)
code = '\n'.join(
- compiler.ast_to_source(dep, indentation)
- for dep in reversed(tuple(program_ctx.dependency_cache.values())))
+ compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
+ for dep in reversed(program_ctx.conversion_order))
return program_ctx.required_imports + '\n\n' + code
diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
index 1433f9a..fca0eb6 100644
--- a/tensorflow/python/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -27,6 +27,7 @@
from __future__ import print_function
import collections
+import weakref
from enum import Enum
# pylint:disable=g-bad-import-order
@@ -61,7 +62,10 @@
def freeze(self):
self.next = frozenset(self.next)
- self.prev = frozenset(self.prev)
+ # Assumption: All CFG nodes have identical life spans, because the graph
+ # owns them. Nodes should never be used outside the context of an existing
+ # graph.
+ self.prev = weakref.WeakSet(self.prev)
def __repr__(self):
if isinstance(self.ast_node, gast.FunctionDef):
@@ -256,7 +260,7 @@
"""Resets the state of this factory."""
self.head = None
self.errors = set()
- self.node_index = collections.OrderedDict()
+ self.node_index = {}
# TODO(mdan): Too many primitives. Use classes.
self.leaves = set()
@@ -309,7 +313,10 @@
"""Grows the graph by adding a CFG node following the current leaves."""
if ast_node is self.node_index:
raise ValueError('%s added twice' % ast_node)
- node = Node(next_=set(), prev=set(), ast_node=ast_node)
+ # Assumption: All CFG nodes have identical life spans, because the graph
+ # owns them. Nodes should never be used outside the context of an existing
+ # graph.
+ node = Node(next_=set(), prev=weakref.WeakSet(), ast_node=ast_node)
self.node_index[ast_node] = node
self.owners[node] = frozenset(self.active_stmts)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index 9cb5991..086eda7 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -22,6 +22,7 @@
from __future__ import print_function
import copy
+import weakref
import gast
@@ -126,7 +127,10 @@
self.parent.mark_read(name)
def mark_param(self, name, owner):
- self.params[name] = owner
+ # Assumption: all AST nodes have the same life span. This lets us use
+ # a weak reference to mark the connection between a symbol node and the
+ # function node whose argument that symbol is.
+ self.params[name] = weakref.ref(owner)
def mark_creation(self, name, writes_create_symbol=False):
"""Mark a qualified name as created."""
diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values.py b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
index 48b442f..36b9e70 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
@@ -29,10 +29,11 @@
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
# TODO(aqj): Do we need this? Do other builtins fail in similar ways
# See b/114389775 for a related bug in pyct
# These symbols are legal in Python, but don't appear in the namespace.
-_special_symbols = {'range': range}
+_SPECIAL_SYMBOLS = {'range': range, 'print': print}
class LiveValueResolver(transformer.Base):
@@ -71,8 +72,10 @@
# If the symbol value is for example a primitive, then it will not
# have a name.
pass
- elif node.id in _special_symbols:
- anno.setanno(node, 'live_val', _special_symbols[node.id])
+ elif node.id in _SPECIAL_SYMBOLS:
+ # Note: if the user redefined any of these symbols, then they would
+ # be visible in the namespace and we would never reach this branch.
+ anno.setanno(node, 'live_val', _SPECIAL_SYMBOLS[node.id])
else:
pass
# TODO(mdan): Should we raise an error here?
@@ -86,7 +89,8 @@
if has_single_def:
def_, = defs
- if def_.param_of is self.enclosing_entities[0]:
+ # Note: param_of is a weakref.
+ if def_.param_of and def_.param_of() is self.enclosing_entities[0]:
if node.id in self.entity_info.arg_values:
obj = self.entity_info.arg_values[node.id]
anno.setanno(node, 'live_val', obj)
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index ae0ad27..c963cfd 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -178,16 +178,30 @@
feed_function_for_partial_run: A callable for specifying tensor values to
feed when setting up a partial run, which takes a `tensor_type` type
object as input, and returns a list of Tensors.
+
+ Raises:
+ ValueError: If `tensor_type` has already been registered.
"""
for conversion_function in _REGISTERED_EXPANSIONS:
if issubclass(conversion_function[0], tensor_type):
- raise ValueError('%s has already been registered so ignore it.',
+ raise ValueError('%s has already been registered so ignore it.' %
tensor_type)
- return
+
_REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function,
feed_function_for_partial_run))
+def _is_attrs_instance(obj):
+ """Returns True if the given obj is an instance of attrs-decorated class."""
+ return getattr(obj.__class__, '__attrs_attrs__', None) is not None
+
+
+def _get_attrs_values(obj):
+ """Returns the list of values from an attrs instance."""
+ attrs = getattr(obj.__class__, '__attrs_attrs__')
+ return [getattr(obj, a.name) for a in attrs]
+
+
class _FetchMapper(object):
"""Definition of the interface provided by fetch mappers.
@@ -247,6 +261,8 @@
return _ListFetchMapper(fetch)
elif isinstance(fetch, collections.Mapping):
return _DictFetchMapper(fetch)
+ elif _is_attrs_instance(fetch):
+ return _AttrsFetchMapper(fetch)
else:
# Look for a handler in the registered expansions.
for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS:
@@ -398,6 +414,32 @@
return results
+class _AttrsFetchMapper(_FetchMapper):
+ """Fetch mapper for attrs decorated classes."""
+
+ def __init__(self, fetches):
+ """Creates a _AttrsFetchMapper.
+
+ Args:
+ fetches: An instance of an attrs decorated class.
+ """
+ values = _get_attrs_values(fetches)
+ self._fetch_type = type(fetches)
+ self._mappers = [
+ _FetchMapper.for_fetch(fetch) for fetch in values
+ ]
+ self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
+
+ def unique_fetches(self):
+ return self._unique_fetches
+
+ def build_results(self, values):
+ results = []
+ for m, vi in zip(self._mappers, self._value_indices):
+ results.append(m.build_results([values[j] for j in vi]))
+ return self._fetch_type(*results)
+
+
class _FetchHandler(object):
"""Handler for structured fetches.
diff --git a/tensorflow/python/client/session_ref.cc b/tensorflow/python/client/session_ref.cc
new file mode 100644
index 0000000..b2300df
--- /dev/null
+++ b/tensorflow/python/client/session_ref.cc
@@ -0,0 +1,515 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/client/session_ref.h"
+
+#include <stdlib.h>
+#include <memory>
+#include <utility>
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/named_tensor.pb.h"
+#include "tensorflow/core/protobuf/replay_log.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Scope helper to track active calls and manage session lifetime.
+// SessionRef blocks closing until all active calls complete or are cancelled.
+struct RunCounter {
+ std::shared_ptr<Session> session;
+ uint64* value;
+ mutex* m;
+ condition_variable* cv;
+
+ explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
+ condition_variable* cv)
+ : session(std::move(s)), value(v), m(m), cv(cv) {
+ mutex_lock l(*m);
+ ++*value;
+ }
+
+ ~RunCounter() {
+ mutex_lock l(*m);
+ if (--*value == 0) {
+ cv->notify_all();
+ }
+ }
+};
+
+std::string SessionToHandle(Session* session) {
+ return strings::Printf("%llu", reinterpret_cast<uint64>(session));
+}
+
+// The Session interface has many methods of the form:
+//
+// X(a, b);
+// X(RunOptions, a, b);
+//
+// Not all sessions support the second case (with an empty RunOptions()).
+// We use this variable as a sentinel to dispatch to the correct call.
+RunOptions* kEmptyRunOptions() {
+ static RunOptions* options = new RunOptions();
+ return options;
+}
+
+} // namespace
+
+// Run the given session operation, recording start and end timestamps.
+// If the operation returns a bad status, return after flushing the current
+// log request. This should be run _after_ all request information has been
+// added to the current op.
+#define RUN_WITH_TIMESTAMP(OpName, ...) \
+ op.set_start_time_us(Env::Default()->NowMicros()); \
+ Status status = session->OpName(__VA_ARGS__); \
+ op.set_end_time_us(Env::Default()->NowMicros()); \
+ if (!status.ok()) { \
+ Flush(op).IgnoreError(); \
+ return status; \
+ }
+
+// Records requests (and optionally responses) performed against a session.
+// The resulting replay log can be used with the `tf_replay` tool to replicate
+// the operations against a simulated environment, without requiring the
+// original code or cluster setup.
+//
+// Session logging by setting the TF_REPLAY_LOG_FILE environment variable.
+class SessionLogger {
+ public:
+ SessionLogger() {
+ std::string log_name = getenv("TF_REPLAY_LOG_FILE");
+ TF_CHECK_OK(
+ Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
+ Env::Default()->DeleteFile(log_name).IgnoreError();
+ TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
+
+ log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
+ }
+
+ Status RecordCreateSession(Session* session) {
+ LOG(INFO) << "Capturing devices for session.";
+ ReplayOp op;
+ NewReplaySession* req = op.mutable_new_replay_session();
+
+ std::vector<DeviceAttributes> devices;
+ TF_CHECK_OK(session->ListDevices(&devices));
+ for (const DeviceAttributes& dev : devices) {
+ *req->mutable_devices()->add_local_device() = dev;
+ }
+
+ req->set_session_handle(SessionToHandle(session));
+ return Flush(op);
+ }
+
+ Status RecordRun(Session* session,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) {
+ return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
+ target_node_names, outputs, nullptr);
+ }
+
+ Status RecordRun(Session* session, const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
+ ReplayOp op;
+ RunStepRequest* req = op.mutable_run_step();
+ RunStepResponse* resp = op.mutable_run_step_response();
+
+ req->set_session_handle(SessionToHandle(session));
+ *req->mutable_options() = run_options;
+
+ for (const auto& it : inputs) {
+ NamedTensorProto* feed = req->add_feed();
+ feed->set_name(it.first);
+ it.second.AsProtoField(feed->mutable_tensor());
+ }
+
+ // Build an index from fetch tensor name to first index in
+ // output_tensor_names.
+ std::unordered_map<string, int> output_name_to_offset;
+ for (int i = 0; i < output_tensor_names.size(); ++i) {
+ const string& name = output_tensor_names[i];
+ if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
+ req->add_fetch(name);
+ }
+ }
+ for (const string& target : target_node_names) {
+ req->add_target(target);
+ }
+
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
+ outputs);
+ } else {
+ RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
+ target_node_names, outputs, run_metadata);
+ }
+
+ for (size_t i = 0; i < outputs->size(); ++i) {
+ const Tensor& tensor = (*outputs)[i];
+ NamedTensorProto* tproto = resp->add_tensor();
+ tensor.AsProtoField(tproto->mutable_tensor());
+ tproto->set_name(output_tensor_names[i]);
+ }
+
+ if (run_metadata) {
+ *resp->mutable_metadata() = *run_metadata;
+ }
+
+ return Flush(op);
+ }
+
+ Status RecordCreate(Session* session, const GraphDef& graph) {
+ return RecordCreate(session, *kEmptyRunOptions(), graph);
+ }
+
+ // N.B. RunOptions is not stored (it has no entry in CreateRequest)
+ Status RecordCreate(Session* session, const RunOptions& run_options,
+ const GraphDef& graph) {
+ ReplayOp op;
+ CreateSessionRequest* req = op.mutable_create_session();
+ *req->mutable_graph_def() = graph;
+
+ CreateSessionResponse* resp = op.mutable_create_session_response();
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Create, graph);
+ } else {
+ RUN_WITH_TIMESTAMP(Create, run_options, graph);
+ }
+ resp->set_session_handle(SessionToHandle(session));
+ return Flush(op);
+ }
+
+ Status RecordExtend(Session* session, const GraphDef& graph) {
+ return RecordExtend(session, *kEmptyRunOptions(), graph);
+ }
+
+ // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
+ Status RecordExtend(Session* session, const RunOptions& run_options,
+ const GraphDef& graph) {
+ ReplayOp op;
+ ExtendSessionRequest* req = op.mutable_extend_session();
+ op.mutable_extend_session_response();
+ req->set_session_handle(SessionToHandle(session));
+ *req->mutable_graph_def() = graph;
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Extend, graph);
+ } else {
+ RUN_WITH_TIMESTAMP(Extend, run_options, graph);
+ }
+
+ return Flush(op);
+ }
+
+ Status RecordClose(Session* session) {
+ return RecordClose(session, *kEmptyRunOptions());
+ }
+
+ // N.B. RunOptions is not stored (it has no entry in CloseRequest)
+ Status RecordClose(Session* session, const RunOptions& run_options) {
+ mutex_lock l(log_mutex_);
+ ReplayOp op;
+ CloseSessionRequest* req = op.mutable_close_session();
+ req->set_session_handle(SessionToHandle(session));
+ op.mutable_close_session_response();
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Close);
+ } else {
+ RUN_WITH_TIMESTAMP(Close, run_options);
+ }
+ return Flush(op);
+ }
+
+ Status RecordListDevices(Session* session,
+ std::vector<DeviceAttributes>* response) {
+ mutex_lock l(log_mutex_);
+ ReplayOp op;
+ ListDevicesRequest* req = op.mutable_list_devices();
+ ListDevicesResponse* resp = op.mutable_list_devices_response();
+ req->set_session_handle(SessionToHandle(session));
+ RUN_WITH_TIMESTAMP(ListDevices, response);
+
+ // TODO(power) -- local vs remote device distinction is lost here!
+ *resp->mutable_local_device() = {response->begin(), response->end()};
+ return Flush(op);
+ }
+
+ Status RecordPRunSetup(Session* session,
+ const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ mutex_lock l(log_mutex_);
+ ReplayOp op;
+ PartialRunSetupRequest* req = op.mutable_partial_run_setup();
+ req->set_session_handle(SessionToHandle(session));
+ for (auto& input : input_names) {
+ req->add_feed(input);
+ }
+ for (auto& output : output_names) {
+ req->add_fetch(output);
+ }
+ for (auto& target : target_nodes) {
+ req->add_target(target);
+ }
+ RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
+ handle);
+ op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
+ return Flush(op);
+ }
+
+ Status RecordPRun(Session* session, const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ ReplayOp op;
+ RunStepRequest* req = op.mutable_run_step();
+ RunStepResponse* resp = op.mutable_run_step_response();
+ req->set_session_handle(SessionToHandle(session));
+
+ // Mark this step as a partial run for replay.
+ req->set_partial_run_handle(handle);
+ for (auto& input : inputs) {
+ auto* feed = req->add_feed();
+ feed->set_name(input.first);
+ input.second.AsProtoField(feed->mutable_tensor());
+ }
+
+ for (auto& output : output_names) {
+ req->add_fetch(output);
+ }
+
+ RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);
+
+ for (size_t i = 0; i < outputs->size(); ++i) {
+ const Tensor& tensor = (*outputs)[i];
+ NamedTensorProto* tproto = resp->add_tensor();
+ tensor.AsProtoField(tproto->mutable_tensor());
+ tproto->set_name(output_names[i]);
+ }
+
+ return Flush(op);
+ }
+
+ Status RecordMakeCallable(Session* session,
+ const CallableOptions& callable_options,
+ Session::CallableHandle* handle) {
+ ReplayOp op;
+ MakeCallableRequest* req = op.mutable_make_callable();
+ req->set_session_handle(SessionToHandle(session));
+ *req->mutable_options() = callable_options;
+
+ RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);
+
+ MakeCallableResponse* resp = op.mutable_make_callable_response();
+ resp->set_handle(*handle);
+
+ return Flush(op);
+ }
+
+ Status RecordRunCallable(Session* session, Session::CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ ReplayOp op;
+ RunCallableRequest* req = op.mutable_run_callable();
+ req->set_session_handle(SessionToHandle(session));
+ req->set_handle(handle);
+ for (auto& tensor : feed_tensors) {
+ tensor.AsProtoField(req->add_feed());
+ }
+ RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
+ run_metadata);
+
+ RunCallableResponse* resp = op.mutable_run_callable_response();
+ if (run_metadata) {
+ *resp->mutable_metadata() = *run_metadata;
+ }
+ for (const Tensor& tensor : *fetch_tensors) {
+ tensor.AsProtoTensorContent(resp->add_fetch());
+ }
+ return Flush(op);
+ }
+
+ Status RecordReleaseCallable(Session* session,
+ Session::CallableHandle handle) {
+ ReplayOp op;
+ ReleaseCallableRequest* req = op.mutable_release_callable();
+ req->set_session_handle(SessionToHandle(session));
+ req->set_handle(handle);
+ RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
+ return Flush(op);
+ }
+
+ private:
+ Status Flush(const ReplayOp& op) {
+ string buf;
+ op.SerializeToString(&buf);
+ TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));
+
+ // Flushing the RecordWriter _does not_ flush the underlying file.
+ TF_RETURN_IF_ERROR(log_writer_->Flush());
+ return log_file_->Flush();
+ }
+
+ mutex log_mutex_;
+ std::unique_ptr<io::RecordWriter> log_writer_;
+ std::unique_ptr<WritableFile> log_file_;
+};
+
+static SessionLogger* global_session_logger() {
+ static SessionLogger* logger = new SessionLogger();
+ return logger;
+}
+
+SessionRef::SessionRef(Session* session) : session_(session) {
+ if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
+ logger_ = global_session_logger();
+ logger_->RecordCreateSession(this->session_.get()).IgnoreError();
+ } else {
+ logger_ = nullptr;
+ }
+}
+
+SessionRef::~SessionRef() = default;
+
+Status SessionRef::CheckNotClosed() {
+ mutex_lock l(run_lock_);
+ if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
+ return ::tensorflow::Status::OK();
+}
+
+// If logging is active, log the start and end time of the operation along with
+// the request and response.
+#define LOG_AND_RUN_OPERATION(OpName, ...) \
+ TF_RETURN_IF_ERROR(CheckNotClosed()); \
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
+ if (!logger_) { \
+ return rc.session->OpName(__VA_ARGS__); \
+ } \
+ return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);
+
+Status SessionRef::Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs,
+ RunMetadata* run_metadata) {
+ LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
+ target_node_names, outputs, run_metadata);
+}
+
+Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) {
+ LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
+ outputs);
+}
+
+Status SessionRef::Create(const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Create, graph);
+}
+
+Status SessionRef::Create(const RunOptions& run_options,
+ const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Create, run_options, graph);
+}
+
+Status SessionRef::Extend(const RunOptions& run_options,
+ const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Extend, run_options, graph);
+}
+
+Status SessionRef::Extend(const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Extend, graph);
+}
+
+Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
+ LOG_AND_RUN_OPERATION(ListDevices, response);
+}
+
+Status SessionRef::PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
+ handle);
+}
+
+Status SessionRef::PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
+}
+
+Status SessionRef::MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) {
+ LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
+}
+
+Status SessionRef::RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
+ run_metadata);
+}
+
+Status SessionRef::ReleaseCallable(CallableHandle handle) {
+ LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
+}
+
+Status SessionRef::Close(const RunOptions& run_options) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ mutex_lock l(run_lock_);
+ Status status;
+ if (logger_) {
+ status = logger_->RecordClose(session_.get(), run_options);
+ } else {
+ status = session_->Close(run_options);
+ }
+ session_.reset();
+ while (run_count_ > 0) {
+ run_finished_.wait(l);
+ }
+ return status;
+}
+
+Status SessionRef::Close() {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ mutex_lock l(run_lock_);
+ Status status;
+ if (logger_) {
+ status = logger_->RecordClose(session_.get());
+ } else {
+ status = session_->Close();
+ }
+ session_.reset();
+ while (run_count_ > 0) {
+ run_finished_.wait(l);
+ }
+ return status;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/session_ref.h b/tensorflow/python/client/session_ref.h
similarity index 90%
rename from tensorflow/core/common_runtime/session_ref.h
rename to tensorflow/python/client/session_ref.h
index 9459e7e..b0fb12b 100644
--- a/tensorflow/core/common_runtime/session_ref.h
+++ b/tensorflow/python/client/session_ref.h
@@ -12,8 +12,8 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+#ifndef TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
+#define TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
#include <memory>
@@ -22,6 +22,8 @@
namespace tensorflow {
+class SessionLogger;
+
// A `SessionRef` manages the lifetime of a wrapped `Session` pointer.
//
// SessionRef blocks the return of Close() until all pending operations have
@@ -29,8 +31,8 @@
// subsequent operations on the SessionRef object will return errors::Cancelled.
class SessionRef : public Session {
public:
- SessionRef(Session* session) : session_(session) {}
- virtual ~SessionRef() {}
+ explicit SessionRef(Session* session);
+ ~SessionRef() override;
Status Create(const GraphDef& graph) override;
Status Extend(const GraphDef& graph) override;
@@ -78,9 +80,12 @@
uint64 run_count_ GUARDED_BY(run_lock_) = {0};
std::shared_ptr<Session> session_;
+ // Borrowed reference to global session logger.
+ SessionLogger* logger_;
+
Status CheckNotClosed();
};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+#endif // TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 4afc639..f576435 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -61,6 +61,12 @@
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
+try:
+ import attr # pylint:disable=g-import-not-at-top
+except ImportError:
+ attr = None
+
+
# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
# don't have C++ op registrations on which to attach C++ shape fns.
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
@@ -300,6 +306,82 @@
self.assertEqual(None, res[2])
self.assertEqual(44.0, res[1])
+ def testFetchAttrs(self):
+ if attr is None:
+ self.skipTest('attr module is unavailable.')
+
+ @attr.s
+ class SampleAttr(object):
+ field1 = attr.ib()
+ field2 = attr.ib()
+
+ val1 = np.array([1.2, 3.4, 5.6])
+ val2 = np.array([[1, 2], [4, 3]])
+ val3 = np.array([10, 20, 30])
+
+ t1 = constant_op.constant(val1)
+ t2 = constant_op.constant(val2)
+
+ sample = SampleAttr(t1, t2)
+ with session.Session() as sess:
+ result = sess.run(sample)
+ self.assertIsInstance(result, SampleAttr)
+ self.assertAllEqual(val1, result.field1)
+ self.assertAllEqual(val2, result.field2)
+
+ result = sess.run(sample, feed_dict={sample.field1: val3})
+ self.assertIsInstance(result, SampleAttr)
+ self.assertAllEqual(val3, result.field1)
+ self.assertAllEqual(val2, result.field2)
+
+ def testFetchNestedAttrs(self):
+ if attr is None:
+ self.skipTest('attr module is unavailable.')
+
+ @attr.s
+ class SampleAttr(object):
+ field0 = attr.ib()
+ field1 = attr.ib()
+
+ v1 = 10
+ v2 = 20
+ v3 = np.float32(1.2)
+ v4 = np.float32(3.4)
+ v5 = np.float64(100.001)
+ v6 = np.float64(-23.451)
+ arr1 = np.array([1.2, 6.7, 3.4])
+ arr2 = np.array([7, 11, 3])
+ sample = SampleAttr(
+ SampleAttr(
+ SampleAttr(constant_op.constant(v1), constant_op.constant(v2)),
+ SampleAttr(constant_op.constant(arr1), constant_op.constant(arr2))),
+ {'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)),
+ 'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))]})
+
+ with session.Session() as sess:
+ result = sess.run(sample)
+ self.assertIsInstance(result, SampleAttr)
+ self.assertIsInstance(result.field0, SampleAttr)
+ self.assertIsInstance(result.field0.field0, SampleAttr)
+ self.assertIsInstance(result.field0.field1, SampleAttr)
+ self.assertIsInstance(result.field0.field1.field0, np.ndarray)
+ self.assertAllEqual(arr1, result.field0.field1.field0)
+ self.assertIsInstance(result.field0.field1.field1, np.ndarray)
+ self.assertAllEqual(arr2, result.field0.field1.field1)
+ self.assertIsInstance(result.field1, dict)
+ self.assertIn('A', result.field1)
+ self.assertIn('B', result.field1)
+ self.assertIsInstance(result.field1['A'], SampleAttr)
+ self.assertAllEqual(
+ [v3, v4],
+ [result.field1['A'].field0, result.field1['A'].field1])
+ self.assertIsInstance(result.field1['B'], list)
+ self.assertEqual(1, len(result.field1['B']))
+ self.assertIsInstance(result.field1['B'][0], SampleAttr)
+ self.assertAllEqual(
+ [v5, v6],
+ [result.field1['B'][0].field0, result.field1['B'][0].field1])
+
def testFetchNestingEmptyOneLevel(self):
with session.Session() as sess:
a_val = 11.0
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 39a2922..ef7527d 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -463,7 +463,7 @@
}
// Override default py3 behavior of attempting to encode into Unicode.
-%typemap(out) std::string tensorflow::GetResourceHandleShapeAndType {
+%typemap(out) std::string tensorflow::GetHandleShapeAndType {
$result = PyBytes_FromStringAndSize($1.data(), $1.size());
}
@@ -782,7 +782,7 @@
%unignore TF_TryEvaluateConstant_wrapper;
%noexception TF_TryEvaluateConstant_wrapper;
%unignore ExtendSession;
-%unignore ResourceHandleShapeAndType;
+%unignore HandleShapeAndType;
%include "tensorflow/python/client/tf_session_helper.h"
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index bcd4af2..dc0c10b 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -20,7 +20,6 @@
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
-#include "tensorflow/core/common_runtime/session_ref.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
@@ -31,6 +30,7 @@
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/equal_graph_def.h"
+#include "tensorflow/python/client/session_ref.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
diff --git a/tensorflow/python/client/timeline.py b/tensorflow/python/client/timeline.py
index 1e96ac5..c3f3829 100644
--- a/tensorflow/python/client/timeline.py
+++ b/tensorflow/python/client/timeline.py
@@ -588,7 +588,8 @@
alloc_tensor_set = set()
alloc_maxes[allocator] = AllocationMaximum(
timestamp=0, num_bytes=0, tensors=set())
- for time, num_bytes, name in alloc_list:
+ for time, num_bytes, name in sorted(
+ alloc_list, key=lambda allocation: allocation[0]):
total_bytes += num_bytes
if num_bytes < 0:
alloc_tensor_set.discard(name)
diff --git a/tensorflow/python/client/timeline_test.py b/tensorflow/python/client/timeline_test.py
index 281d7f2..032bbf7 100644
--- a/tensorflow/python/client/timeline_test.py
+++ b/tensorflow/python/client/timeline_test.py
@@ -134,7 +134,7 @@
ctf = tl.generate_chrome_trace_format()
self._validateTrace(ctf)
- def disabled_testAnalysisAndAllocations(self):
+ def testAnalysisAndAllocations(self):
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
@@ -163,8 +163,6 @@
# At least num1 + num2, both float32s (4 bytes each)
self.assertGreaterEqual(cpu_max.num_bytes, 8)
self.assertGreater(cpu_max.timestamp, 0)
- self.assertTrue('num1' in cpu_max.tensors or 'num1/read' in cpu_max.tensors)
- self.assertTrue('num2' in cpu_max.tensors or 'num2/read' in cpu_max.tensors)
def testManyCPUs(self):
run_options = config_pb2.RunOptions(
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 157e699..5e8f5d6 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 18)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 21)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py
index ff49b69..91f21cb 100644
--- a/tensorflow/python/debug/lib/session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py
@@ -741,7 +741,7 @@
debug_server) = grpc_debug_test_server.start_server_on_separate_thread(
server_start_delay_sec=2.0, dump_to_filesystem=False)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a_init = constant_op.constant(42.0, name="a_init")
a = variables.Variable(a_init, name="a")
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index c1bc27d..a2686c6 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -34,6 +34,7 @@
"//tensorflow/python:safe_ptr",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
+ "@com_google_absl//absl/types:variant",
],
)
@@ -146,6 +147,7 @@
"//tensorflow/python:clip_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:layers",
+ "//tensorflow/python:list_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:resource_variable_ops",
],
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 50a6ce6..d95e0fe 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -608,8 +608,9 @@
_default_vspace = imperative_grad.VSpace(
num_elements_fn=_num_elements,
aggregate_fn=_aggregate_grads,
- zeros=_zeros,
- ones=_ones)
+ zeros_fn=_zeros,
+ ones_fn=_ones,
+ graph_shape_fn=gen_array_ops.shape)
pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index f938ed5..3273174 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1022,6 +1022,18 @@
resource_variable_ops.ResourceVariable(2.0))
self.assertAllEqual(gradients_constants, gradients_variables)
+ def testUnknownShapes(self):
+ with context.graph_mode():
+ with backprop.GradientTape() as tape:
+ a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+ tape.watch(a)
+ b = a**3
+
+ db_da = tape.gradient(b, a)
+
+ with self.cached_session() as sess:
+ self.assertEqual((8.0, 12.0), sess.run((b, db_da), feed_dict={a: 2.0}))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 4f1a85a..bcb1881 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -73,16 +73,36 @@
with ops.control_dependencies(None):
placeholder = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
- if placeholder.dtype == dtypes_module.resource:
- if isinstance(value, ops.EagerTensor):
- handle_data = value._handle_data # pylint: disable=protected-access
+ _copy_handle_data(value, placeholder)
+ return placeholder
+
+
+def _copy_handle_data(source_t, target_t):
+ """Copies HandleData for variant and resource type tensors if available.
+
+ The CppShapeInferenceResult::HandleData proto contains information about the
+ shapes and types of the element tensors of resource/variant type tensors.
+ We need to copy this across function boundaries, i.e., when capturing a
+ placeholder or when returning a function tensor as output. If we don't do this
+ the element tensors will have unknown shapes, e.g., if a TensorList variant
+ tensor is captured as a placeholder, elements popped from that list would have
+ unknown shape.
+
+ Args:
+ source_t: The tensor to copy HandleData from.
+ target_t: The tensor to copy HandleData to.
+ """
+ if (target_t.dtype == dtypes_module.resource or
+ target_t.dtype == dtypes_module.variant):
+ if isinstance(source_t, ops.EagerTensor):
+ handle_data = source_t._handle_data # pylint: disable=protected-access
else:
- handle_data = resource_variable_ops.get_resource_handle_data(value)
+ handle_data = resource_variable_ops.get_resource_handle_data(source_t)
if handle_data is not None and handle_data.is_set:
# pylint: disable=protected-access
- pywrap_tensorflow.SetResourceHandleShapeAndType(
- placeholder.graph._c_graph, placeholder._as_tf_output(),
- handle_data.SerializeToString())
+ pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+ target_t._as_tf_output(),
+ handle_data.SerializeToString())
# pylint: enable=protected-access
# Ensure that shapes and dtypes are propagated.
shapes, types = zip(*[(pair.shape, pair.dtype)
@@ -91,12 +111,10 @@
shapes = [[d.size for d in s.dim]
if not s.unknown_rank else None for s in shapes]
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
- placeholder._op._graph._c_graph, # pylint: disable=protected-access
- placeholder._as_tf_output(), # pylint: disable=protected-access
+ target_t._op._graph._c_graph, # pylint: disable=protected-access
+ target_t._as_tf_output(), # pylint: disable=protected-access
shapes, ranks, types)
- return placeholder
-
def _get_device_functions(ctx, graph):
"""Returns a tuple of device functions representing the device stack."""
@@ -435,6 +453,7 @@
self._num_outputs = len(self.signature.output_arg)
self._output_types = [o.type for o in self.signature.output_arg]
self._output_shapes = [o.shape for o in outputs]
+ self._func_graph_outputs = outputs
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -511,6 +530,8 @@
else:
for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
+ for i, func_graph_output in enumerate(self._func_graph_outputs):
+ _copy_handle_data(func_graph_output, outputs[i])
return outputs
@@ -826,7 +847,12 @@
return nest.pack_sequence_as(args, function_inputs)
-def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
+def func_graph_from_py_func(name,
+ python_func,
+ args,
+ kwds,
+ signature=None,
+ func_graph=None):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -841,6 +867,8 @@
`kwds` are ignored, and `python_func` is traced with Tensors conforming
to `signature`. If `None`, the shapes and dtypes are inferred from the
inputs.
+ func_graph: Optional. An instance of FuncGraph. If provided, we will use
+ this graph else a new one is built and returned.
Returns:
A FuncGraph.
@@ -849,7 +877,9 @@
TypeError: If any of `python_func`'s return values is neither `None` nor a
`Tensor`.
"""
- func_graph = FuncGraph(name)
+ if func_graph is None:
+ func_graph = FuncGraph(name)
+ assert isinstance(func_graph, FuncGraph)
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 4a1bde3..e4513cc 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -48,6 +48,7 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
@@ -438,10 +439,17 @@
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+ # We do not return v directly since the tensor conversion function of
+ # ResourceVariable returns the read value and not the resource itself.
+ return v._handle
compiled = function.defun(f)
- compiled()
+ var_handle = compiled()
+ self.assertEqual(var_handle.dtype, dtypes.resource)
+ self.assertEqual(var_handle.shape, tensor_shape.scalar())
+ var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+ self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testVariableInLoopInFunction(self):
@@ -465,10 +473,17 @@
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+ # We do not return v directly since the tensor conversion function of
+ # ResourceVariable returns the read value and not the resource itself.
+ return v._handle
compiled = function.defun(f)
- compiled()
+ var_handle = compiled()
+ self.assertEqual(var_handle.dtype, dtypes.resource)
+ self.assertEqual(var_handle.shape, tensor_shape.scalar())
+ var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+ self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
with context.graph_mode():
@@ -477,12 +492,34 @@
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# Check that shape inference works while creating the defun
compiled = function.defun(f)
compiled()
+ def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
+ with context.graph_mode():
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ tensor_list = list_ops.tensor_list_push_back(tensor_list,
+ constant_op.constant(1.0))
+ tensor_list = list_ops.tensor_list_push_back(tensor_list,
+ constant_op.constant(2.0))
+
+ def f():
+ tl, value = list_ops.tensor_list_pop_back(
+ tensor_list, element_dtype=dtypes.float32)
+ self.assertEqual(value.shape, tensor_shape.scalar())
+ return tl
+
+ compiled = function.defun(f)
+ output_tensor_list = compiled()
+ _, value = list_ops.tensor_list_pop_back(
+ output_tensor_list, element_dtype=dtypes.float32)
+ self.assertEqual(value.shape, tensor_shape.scalar())
+
@test_util.run_in_graph_and_eager_modes
def testDefunForcesResourceVariables(self):
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 5f027d1..5f5af4a 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -23,8 +23,9 @@
from tensorflow.python import pywrap_tensorflow
-VSpace = collections.namedtuple(
- "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])
+VSpace = collections.namedtuple("VSpace", [
+ "aggregate_fn", "num_elements_fn", "zeros_fn", "ones_fn", "graph_shape_fn"
+])
def imperative_grad(
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index a0f6be4..196e20e 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -17,6 +17,7 @@
#include "tensorflow/python/eager/pywrap_tfe.h"
+#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -889,12 +890,239 @@
return static_cast<tensorflow::DataType>(id);
}
+class PyTapeTensor {
+ public:
+ PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+ const tensorflow::TensorShape& shape)
+ : id_(id), dtype_(dtype), shape_(shape) {}
+ PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+ PyObject* shape)
+ : id_(id), dtype_(dtype), shape_(shape) {
+ Py_INCREF(absl::get<1>(shape_));
+ }
+ PyTapeTensor(const PyTapeTensor& other) {
+ id_ = other.id_;
+ dtype_ = other.dtype_;
+ shape_ = other.shape_;
+ if (shape_.index() == 1) {
+ Py_INCREF(absl::get<1>(shape_));
+ }
+ }
+
+ ~PyTapeTensor() {
+ if (shape_.index() == 1) {
+ Py_DECREF(absl::get<1>(shape_));
+ }
+ }
+ PyObject* GetShape() const;
+ PyObject* GetDType() const { return PyLong_FromLong(dtype_); }
+ tensorflow::int64 GetID() const { return id_; }
+
+ private:
+ tensorflow::int64 id_;
+ tensorflow::DataType dtype_;
+ absl::variant<tensorflow::TensorShape, PyObject*> shape_;
+};
+
+class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
+ public:
+ explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
+ Py_INCREF(py_vspace_);
+ }
+
+ tensorflow::Status Initialize() {
+ num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
+ if (num_elements_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
+ if (aggregate_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
+ if (zeros_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
+ if (ones_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
+ if (graph_shape_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ return tensorflow::Status::OK();
+ }
+
+ ~PyVSpace() override {
+ Py_XDECREF(num_elements_);
+ Py_XDECREF(aggregate_fn_);
+ Py_XDECREF(zeros_fn_);
+ Py_XDECREF(ones_fn_);
+ Py_XDECREF(graph_shape_fn_);
+
+ Py_DECREF(py_vspace_);
+ }
+
+ tensorflow::int64 NumElements(PyObject* tensor) const final {
+ if (EagerTensor_CheckExact(tensor)) {
+ return PyEagerTensor_NumElements(tensor);
+ }
+ PyObject* arglist =
+ Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
+ PyObject* result = PyEval_CallObject(num_elements_, arglist);
+ Py_DECREF(arglist);
+ if (result == nullptr) {
+ // The caller detects whether a python exception has been raised.
+ return -1;
+ }
+ tensorflow::int64 r = MakeInt(result);
+ Py_DECREF(result);
+ return r;
+ }
+
+ PyObject* AggregateGradients(
+ tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
+ PyObject* list = PyList_New(gradient_tensors.size());
+ for (int i = 0; i < gradient_tensors.size(); ++i) {
+ // Note: stealing a reference to the gradient tensors.
+ CHECK(gradient_tensors[i] != nullptr);
+ CHECK(gradient_tensors[i] != Py_None);
+ PyList_SET_ITEM(list, i,
+ reinterpret_cast<PyObject*>(gradient_tensors[i]));
+ }
+ PyObject* arglist = Py_BuildValue("(O)", list);
+ CHECK(arglist != nullptr);
+ PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
+ Py_DECREF(arglist);
+ Py_DECREF(list);
+ return result;
+ }
+
+ void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
+
+ PyObject* Zeros(const PyTapeTensor& tensor) const final {
+ PyObject* py_shape = tensor.GetShape();
+ PyObject* py_dtype = tensor.GetDType();
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return reinterpret_cast<PyObject*>(result);
+ }
+
+ PyObject* Ones(const PyTapeTensor& tensor) const final {
+ PyObject* py_shape = tensor.GetShape();
+ PyObject* py_dtype = tensor.GetDType();
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return result;
+ }
+
+ PyObject* GraphShape(PyObject* tensor) const {
+ PyObject* arg_list = Py_BuildValue("(O)", tensor);
+ PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
+ Py_DECREF(arg_list);
+ return result;
+ }
+
+ tensorflow::Status CallBackwardFunction(
+ PyBackwardFunction* backward_function,
+ tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
+ std::vector<PyObject*>* result) const final {
+ PyObject* grads = PyTuple_New(output_gradients.size());
+ for (int i = 0; i < output_gradients.size(); ++i) {
+ if (output_gradients[i] == nullptr) {
+ Py_INCREF(Py_None);
+ PyTuple_SET_ITEM(grads, i, Py_None);
+ } else {
+ PyTuple_SET_ITEM(grads, i,
+ reinterpret_cast<PyObject*>(output_gradients[i]));
+ }
+ }
+ PyObject* py_result = (*backward_function)(grads);
+ Py_DECREF(grads);
+ if (py_result == nullptr) {
+ return tensorflow::errors::Internal("gradient function threw exceptions");
+ }
+ result->clear();
+ PyObject* seq =
+ PySequence_Fast(py_result, "expected a sequence of gradients");
+ if (seq == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "gradient function did not return a list");
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ VLOG(1) << "Gradient length is " << len;
+ result->reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+ if (item == Py_None) {
+ result->push_back(nullptr);
+ } else {
+ Py_INCREF(item);
+ result->push_back(item);
+ }
+ }
+ Py_DECREF(seq);
+ Py_DECREF(py_result);
+ return tensorflow::Status::OK();
+ }
+
+ void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
+
+ private:
+ PyObject* py_vspace_;
+
+ PyObject* num_elements_;
+ PyObject* aggregate_fn_;
+ PyObject* zeros_fn_;
+ PyObject* ones_fn_;
+ PyObject* graph_shape_fn_;
+};
+PyVSpace* py_vspace = nullptr;
+
+PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
+ if (py_vspace != nullptr) {
+ delete py_vspace;
+ }
+
+ py_vspace = new PyVSpace(e);
+ auto status = py_vspace->Initialize();
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+ delete py_vspace;
+ return nullptr;
+ }
+
+ Py_RETURN_NONE;
+}
+
+PyObject* PyTapeTensor::GetShape() const {
+ if (shape_.index() == 0) {
+ auto& shape = absl::get<0>(shape_);
+ PyObject* py_shape = PyTuple_New(shape.dims());
+ for (int i = 0; i < shape.dims(); ++i) {
+ PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
+ }
+
+ return py_shape;
+ }
+
+ return py_vspace->GraphShape(absl::get<1>(shape_));
+}
+
class GradientTape
- : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
+ : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
public:
explicit GradientTape(bool persistent, bool watch_accessed_variables)
- : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
- persistent),
+ : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor>(persistent),
watch_accessed_variables_(watch_accessed_variables) {}
virtual ~GradientTape() {
@@ -1175,7 +1403,24 @@
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
-static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
+bool ListContainsNone(PyObject* list) {
+ if (list == Py_None) return true;
+ tensorflow::Safe_PyObjectPtr seq(
+ PySequence_Fast(list, "expected a sequence"));
+ if (seq == nullptr) {
+ return false;
+ }
+
+ int len = PySequence_Size(list);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
+ if (item == Py_None) return true;
+ }
+
+ return false;
+}
+
+static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
tensorflow::int64 id = PyEagerTensor_ID(tensor);
@@ -1183,16 +1428,16 @@
const tensorflow::Status status = t->handle->Shape(&tensor_shape);
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype,
- tensorflow::TensorShape({})};
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
} else {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape};
+ return PyTapeTensor(id, t->handle->dtype, tensor_shape);
}
}
tensorflow::int64 id = FastTensorId(tensor);
if (PyErr_Occurred()) {
- return tensorflow::eager::TapeTensor{
- id, static_cast<tensorflow::DataType>(0), tensorflow::TensorShape({})};
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
@@ -1200,16 +1445,21 @@
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
Py_DECREF(dtype_enum);
- if (PyErr_Occurred() != nullptr) {
- return tensorflow::eager::TapeTensor{id, dtype,
- tensorflow::TensorShape({})};
+ if (PyErr_Occurred()) {
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
static char _shape_tuple[] = "_shape_tuple";
PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
- if (PyErr_Occurred() != nullptr) {
- return tensorflow::eager::TapeTensor{id, dtype,
- tensorflow::TensorShape({})};
+ if (PyErr_Occurred()) {
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
+
+ if (ListContainsNone(shape_tuple)) {
+ return PyTapeTensor(id, dtype, tensor);
+ }
+
auto l = MakeIntList(shape_tuple);
Py_DECREF(shape_tuple);
// Replace -1, which represents accidental Nones which can occur in graph mode
@@ -1220,7 +1470,7 @@
}
}
tensorflow::TensorShape shape(l);
- return tensorflow::eager::TapeTensor{id, dtype, shape};
+ return PyTapeTensor(id, dtype, shape);
}
std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
@@ -1286,7 +1536,7 @@
const std::vector<tensorflow::DataType>& input_dtypes,
const std::function<PyBackwardFunction*()>& backward_function_getter,
const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
- std::vector<tensorflow::eager::TapeTensor> output_info;
+ std::vector<PyTapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
int len = PySequence_Size(output_tensors);
@@ -1362,180 +1612,6 @@
}
}
-class PyVSpace
- : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
- public:
- explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
- Py_INCREF(py_vspace_);
- }
-
- tensorflow::Status Initialize() {
- num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
- if (num_elements_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
- if (aggregate_fn_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- zeros_ = PyObject_GetAttrString(py_vspace_, "zeros");
- if (zeros_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- ones_ =
- PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones");
- if (ones_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- return tensorflow::Status::OK();
- }
-
- ~PyVSpace() override {
- Py_XDECREF(num_elements_);
- Py_XDECREF(aggregate_fn_);
- Py_XDECREF(zeros_);
- Py_XDECREF(ones_);
-
- Py_DECREF(py_vspace_);
- }
-
- tensorflow::int64 NumElements(PyObject* tensor) const final {
- if (EagerTensor_CheckExact(tensor)) {
- return PyEagerTensor_NumElements(tensor);
- }
- PyObject* arglist =
- Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
- PyObject* result = PyEval_CallObject(num_elements_, arglist);
- Py_DECREF(arglist);
- if (result == nullptr) {
- // The caller detects whether a python exception has been raised.
- return -1;
- }
- tensorflow::int64 r = MakeInt(result);
- Py_DECREF(result);
- return r;
- }
-
- PyObject* AggregateGradients(
- tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
- PyObject* list = PyList_New(gradient_tensors.size());
- for (int i = 0; i < gradient_tensors.size(); ++i) {
- // Note: stealing a reference to the gradient tensors.
- CHECK(gradient_tensors[i] != nullptr);
- CHECK(gradient_tensors[i] != Py_None);
- PyList_SET_ITEM(list, i,
- reinterpret_cast<PyObject*>(gradient_tensors[i]));
- }
- PyObject* arglist = Py_BuildValue("(O)", list);
- CHECK(arglist != nullptr);
- PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
- Py_DECREF(arglist);
- Py_DECREF(list);
- return result;
- }
-
- void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
-
- PyObject* Zeros(tensorflow::TensorShape shape,
- tensorflow::DataType dtype) const final {
- PyObject* py_shape = PyTuple_New(shape.dims());
- for (int i = 0; i < shape.dims(); ++i) {
- PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
- }
- PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
- PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
- PyObject* result = PyEval_CallObject(zeros_, arg_list);
- Py_DECREF(arg_list);
- Py_DECREF(py_dtype);
- Py_DECREF(py_shape);
- return reinterpret_cast<PyObject*>(result);
- }
-
- PyObject* Ones(tensorflow::TensorShape shape,
- tensorflow::DataType dtype) const final {
- PyObject* py_shape = PyTuple_New(shape.dims());
- for (int i = 0; i < shape.dims(); ++i) {
- PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
- }
- PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
- PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
- PyObject* result = PyEval_CallObject(ones_, arg_list);
- Py_DECREF(arg_list);
- Py_DECREF(py_dtype);
- Py_DECREF(py_shape);
- return result;
- }
-
- tensorflow::Status CallBackwardFunction(
- PyBackwardFunction* backward_function,
- tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
- std::vector<PyObject*>* result) const final {
- PyObject* grads = PyTuple_New(output_gradients.size());
- for (int i = 0; i < output_gradients.size(); ++i) {
- if (output_gradients[i] == nullptr) {
- Py_INCREF(Py_None);
- PyTuple_SET_ITEM(grads, i, Py_None);
- } else {
- PyTuple_SET_ITEM(grads, i,
- reinterpret_cast<PyObject*>(output_gradients[i]));
- }
- }
- PyObject* py_result = (*backward_function)(grads);
- Py_DECREF(grads);
- if (py_result == nullptr) {
- return tensorflow::errors::Internal("gradient function threw exceptions");
- }
- result->clear();
- PyObject* seq =
- PySequence_Fast(py_result, "expected a sequence of gradients");
- if (seq == nullptr) {
- return tensorflow::errors::InvalidArgument(
- "gradient function did not return a list");
- }
- int len = PySequence_Fast_GET_SIZE(seq);
- VLOG(1) << "Gradient length is " << len;
- result->reserve(len);
- for (int i = 0; i < len; ++i) {
- PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
- if (item == Py_None) {
- result->push_back(nullptr);
- } else {
- Py_INCREF(item);
- result->push_back(item);
- }
- }
- Py_DECREF(seq);
- Py_DECREF(py_result);
- return tensorflow::Status::OK();
- }
-
- void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
-
- private:
- PyObject* py_vspace_;
-
- PyObject* num_elements_;
- PyObject* aggregate_fn_;
- PyObject* zeros_;
- PyObject* ones_;
-};
-PyVSpace* py_vspace = nullptr;
-
-PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
- if (py_vspace != nullptr) {
- delete py_vspace;
- }
-
- py_vspace = new PyVSpace(e);
- auto status = py_vspace->Initialize();
- if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- delete py_vspace;
- return nullptr;
- }
-
- Py_RETURN_NONE;
-}
-
std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
if (seq == nullptr) {
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index bfcc019..7f23499 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -197,6 +197,7 @@
srcs = ["canned/boosted_trees.py"],
srcs_version = "PY2AND3",
deps = [
+ ":boosted_trees_utils",
":estimator",
":head",
":model_fn",
@@ -224,6 +225,35 @@
)
py_library(
+ name = "boosted_trees_utils",
+ srcs = ["canned/boosted_trees_utils.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":estimator",
+ ":head",
+ ":model_fn",
+ "//tensorflow:tensorflow_py_no_contrib",
+ ],
+)
+
+py_test(
+ name = "boosted_trees_utils_test",
+ size = "medium",
+ srcs = ["canned/boosted_trees_utils_test.py"],
+ shard_count = 2,
+ srcs_version = "PY2AND3",
+ tags = [
+ "optonly",
+ ],
+ deps = [
+ ":boosted_trees",
+ ":inputs",
+ "//tensorflow:tensorflow_py_no_contrib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "dnn",
srcs = ["canned/dnn.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 19f1801..756d32d 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -22,7 +22,8 @@
import functools
from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.canned import boosted_trees_utils
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import dtypes
@@ -36,6 +37,7 @@
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.array_ops import identity as tf_identity
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
@@ -197,8 +199,7 @@
cached_features = [
_local_variable(
array_ops.zeros([batch_size], dtype=dtypes.int32),
- name='cached_feature_{}'.format(i))
- for i in range(num_features)
+ name='cached_feature_{}'.format(i)) for i in range(num_features)
]
are_features_cached = _local_variable(False, name='are_features_cached')
@@ -228,8 +229,7 @@
return cached, cache_flip_op
input_feature_list, cache_flip_op = control_flow_ops.cond(
- are_features_cached,
- lambda: (cached_features, control_flow_ops.no_op()),
+ are_features_cached, lambda: (cached_features, control_flow_ops.no_op()),
cache_features_and_return)
return input_feature_list, cache_flip_op
@@ -263,8 +263,8 @@
elif dtypes.as_dtype(dtypes.string).is_compatible_with(example_ids.dtype):
empty_key = ''
else:
- raise ValueError('Unsupported example_id_feature dtype %s.' %
- example_ids.dtype)
+ raise ValueError(
+ 'Unsupported example_id_feature dtype %s.' % example_ids.dtype)
# Cache holds latest <tree_id, node_id, logits> for each example.
# tree_id and node_id are both int32 but logits is a float32.
# To reduce the overhead, we store all of them together as float32 and
@@ -273,8 +273,8 @@
empty_key=empty_key, value_dtype=dtypes.float32, value_shape=[3])
self._example_ids = ops.convert_to_tensor(example_ids)
if self._example_ids.shape.ndims not in (None, 1):
- raise ValueError('example_id should have rank 1, but got %s' %
- self._example_ids)
+ raise ValueError(
+ 'example_id should have rank 1, but got %s' % self._example_ids)
self._logits_dimension = logits_dimension
def lookup(self):
@@ -334,7 +334,7 @@
array_ops.zeros([batch_size], dtype=dtypes.int32),
name='tree_ids_cache')
self._node_ids = _local_variable(
- _DUMMY_NODE_ID*array_ops.ones([batch_size], dtype=dtypes.int32),
+ _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
name='node_ids_cache')
self._logits = _local_variable(
array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32),
@@ -422,9 +422,13 @@
self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
tree_hparams.pruning_mode)
- if (self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING
- and tree_hparams.tree_complexity <= 0):
- raise ValueError('For pruning, tree_complexity must be positive.')
+ if tree_hparams.tree_complexity > 0:
+ if self._pruning_mode_parsed == boosted_trees_ops.PruningMode.NO_PRUNING:
+ raise ValueError(
+ 'Tree complexity have no effect unless pruning mode is chosen.')
+ else:
+ if self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING:
+ raise ValueError('For pruning, tree_complexity must be positive.')
# pylint: enable=protected-access
@abc.abstractmethod
@@ -719,7 +723,7 @@
tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
# Create logits.
- if mode != model_fn.ModeKeys.TRAIN:
+ if mode != model_fn_lib.ModeKeys.TRAIN:
input_feature_list = _get_transformed_features(features,
sorted_feature_columns)
logits = boosted_trees_ops.predict(
@@ -886,6 +890,7 @@
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
+
# Add an early stop hook.
estimator_spec = estimator_spec._replace(
training_hooks=estimator_spec.training_hooks +
@@ -927,8 +932,8 @@
label_vocabulary):
"""Creates a head for classifier and the closed form gradients/hessians."""
head = _create_classification_head(n_classes, weight_column, label_vocabulary)
- if (n_classes == 2 and head.logits_dimension == 1 and weight_column is None
- and label_vocabulary is None):
+ if (n_classes == 2 and head.logits_dimension == 1 and
+ weight_column is None and label_vocabulary is None):
# Use the closed-form gradients/hessians for 2 class.
def _grad_and_hess_for_logloss(logits, labels):
"""A closed form gradient and hessian for logistic loss."""
@@ -961,8 +966,196 @@
# pylint: enable=protected-access
+def _bt_explanations_fn(features,
+ head,
+ sorted_feature_columns,
+ name='boosted_trees'):
+ """Gradient Boosted Trees predict with explanations model_fn.
+
+ Args:
+ features: dict of `Tensor`.
+ head: A `head_lib._Head` instance.
+ sorted_feature_columns: Sorted iterable of `feature_column._FeatureColumn`
+ model inputs.
+ name: Name used for the model.
+
+ Returns:
+ An `EstimatorSpec` instance.
+
+ Raises:
+ ValueError: mode or params are invalid, or features has the wrong type.
+ """
+ mode = model_fn_lib.ModeKeys.PREDICT
+ with ops.name_scope(name) as name:
+ # Create Ensemble resources.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+
+ input_feature_list = _get_transformed_features(features,
+ sorted_feature_columns)
+
+ logits = boosted_trees_ops.predict(
+ # For non-TRAIN mode, ensemble doesn't change after initialization,
+ # so no local copy is needed; using tree_ensemble directly.
+ tree_ensemble_handle=tree_ensemble.resource_handle,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension)
+
+ estimator_spec = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=None,
+ train_op_fn=control_flow_ops.no_op,
+ logits=logits)
+
+ debug_op = boosted_trees_ops.example_debug_outputs(
+ tree_ensemble.resource_handle,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension)
+ estimator_spec.predictions[boosted_trees_utils._DEBUG_PROTO_KEY] = debug_op # pylint: disable=protected-access
+ return estimator_spec
+
+
+class _BoostedTreesBase(estimator.Estimator):
+ """Base class for boosted trees estimators.
+
+ This class is intended to keep tree-specific functions (E.g., methods for
+ feature importances and directional feature contributions) in one central
+ place.
+
+ It is not a valid (working) Estimator on its own and should only be used as a
+ base class.
+ """
+
+ def __init__(self, model_fn, model_dir, config, feature_columns, head,
+ center_bias, is_classification):
+ """Initializes a `_BoostedTreesBase` instance.
+
+ Args:
+ model_fn: model_fn: Model function. See base class for more detail.
+ model_dir: Directory to save model parameters, graph and etc. See base
+ class for more detail.
+ config: `estimator.RunConfig` configuration object.
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `FeatureColumn`
+ head: A `head_lib._Head` instance.
+ center_bias: Whether bias centering needs to occur. Bias centering refers
+ to the first node in the very first tree returning the prediction that
+ is aligned with the original labels distribution. For example, for
+ regression problems, the first node will return the mean of the labels.
+ For binary classification problems, it will return a logit for a prior
+ probability of label 1.
+ is_classification: If the estimator is for classification.
+ """
+ super(_BoostedTreesBase, self).__init__(
+ model_fn=model_fn, model_dir=model_dir, config=config)
+ self._sorted_feature_columns = sorted(
+ feature_columns, key=lambda tc: tc.name)
+ self._head = head
+ self._n_features = _calculate_num_features(self._sorted_feature_columns)
+ self._center_bias = center_bias
+ self._is_classification = is_classification
+
+ def experimental_predict_with_explanations(self,
+ input_fn,
+ predict_keys=None,
+ hooks=None,
+ checkpoint_path=None):
+ """Computes model explainability outputs per example along with predictions.
+
+ Currently supports directional feature contributions (DFCs). For each
+ instance, DFCs indicate the aggregate contribution of each feature. See
+ https://arxiv.org/abs/1312.1121 and
+ http://blog.datadive.net/interpreting-random-forests/ for more details.
+ Args:
+ input_fn: A function that provides input data for predicting as
+ minibatches. See [Premade Estimators](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
+ the following: * A `tf.data.Dataset` object: Outputs of `Dataset`
+ object must be a tuple `(features, labels)` with same constraints as
+ below. * A tuple `(features, labels)`: Where `features` is a `tf.Tensor`
+ or a dictionary of string feature name to `Tensor` and `labels` is a
+ `Tensor` or a dictionary of string label name to `Tensor`. Both
+ `features` and `labels` are consumed by `model_fn`. They should
+ satisfy the expectation of `model_fn` from inputs.
+ predict_keys: list of `str`, name of the keys to predict. It is used if
+ the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If
+ `predict_keys` is used then rest of the predictions will be filtered
+ from the dictionary, with the exception of 'bias' and 'dfc', which will
+ always be in the dictionary. If `None`, returns all keys in prediction
+ dict, as well as two new keys 'dfc' and 'bias'.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the prediction call.
+ checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
+ latest checkpoint in `model_dir` is used. If there are no checkpoints
+ in `model_dir`, prediction is run with newly initialized `Variables`
+ instead of ones restored from checkpoint.
+
+ Yields:
+ Evaluated values of `predictions` tensors. The `predictions` tensors will
+ contain at least two keys 'dfc' and 'bias' for model explanations. The
+ `dfc` value corresponds to the contribution of each feature to the overall
+ prediction for this instance (positive indicating that the feature makes
+ it more likely to select class 1 and negative less likely). The 'bias'
+ value will be the same across all the instances, corresponding to the
+ probability (classification) or prediction (regression) of the training
+ data distribution.
+
+ Raises:
+ ValueError: when wrong arguments are given or unsupported functionalities
+ are requested.
+ """
+ if not self._center_bias:
+ raise ValueError('center_bias must be enabled during estimator '
+ 'instantiation when using '
+ 'experimental_predict_with_explanations.')
+ # pylint: disable=protected-access
+ if not self._is_classification:
+ identity_inverse_link_fn = self._head._inverse_link_fn in (None,
+ tf_identity)
+ # pylint:enable=protected-access
+ if not identity_inverse_link_fn:
+ raise ValueError(
+ 'For now only identity inverse_link_fn in regression_head is '
+ 'supported for experimental_predict_with_explanations.')
+
+ # pylint:disable=unused-argument
+ def new_model_fn(features, labels, mode):
+ return _bt_explanations_fn(features, self._head,
+ self._sorted_feature_columns)
+
+ # pylint:enable=unused-argument
+ est = estimator.Estimator(
+ model_fn=new_model_fn,
+ model_dir=self.model_dir,
+ config=self.config,
+ warm_start_from=self._warm_start_settings)
+ # Make sure bias and dfc will be in prediction dict.
+ user_supplied_predict_keys = predict_keys is not None
+ if user_supplied_predict_keys:
+ predict_keys = set(predict_keys)
+ predict_keys.add(boosted_trees_utils._DEBUG_PROTO_KEY)
+ predictions = est.predict(
+ input_fn,
+ predict_keys=predict_keys,
+ hooks=hooks,
+ checkpoint_path=checkpoint_path,
+ yield_single_examples=True)
+ for pred in predictions:
+ bias, dfcs = boosted_trees_utils._parse_explanations_from_prediction(
+ pred[boosted_trees_utils._DEBUG_PROTO_KEY], self._n_features,
+ self._is_classification)
+ pred['bias'] = bias
+ pred['dfc'] = dfcs
+ # Don't need to expose serialized proto to end user.
+ del pred[boosted_trees_utils._DEBUG_PROTO_KEY]
+ yield pred
+
+
+# pylint: disable=protected-access
@estimator_export('estimator.BoostedTreesClassifier')
-class BoostedTreesClassifier(estimator.Estimator):
+class BoostedTreesClassifier(_BoostedTreesBase):
"""A Classifier for Tensorflow Boosted Trees models.
@compatibility(eager)
@@ -1082,14 +1275,13 @@
n_classes = 2
head, closed_form = _create_classification_head_and_closed_form(
n_classes, weight_column, label_vocabulary=label_vocabulary)
-
# HParams for the model.
tree_hparams = _TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
- return _bt_model_fn( # pylint: disable=protected-access
+ return _bt_model_fn(
features,
labels,
mode,
@@ -1101,11 +1293,17 @@
closed_form_grad_and_hess_fn=closed_form)
super(BoostedTreesClassifier, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns,
+ head=head,
+ center_bias=center_bias,
+ is_classification=True)
@estimator_export('estimator.BoostedTreesRegressor')
-class BoostedTreesRegressor(estimator.Estimator):
+class BoostedTreesRegressor(_BoostedTreesBase):
"""A Regressor for Tensorflow Boosted Trees models.
@compatibility(eager)
@@ -1223,9 +1421,17 @@
tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
- return _bt_model_fn( # pylint: disable=protected-access
- features, labels, mode, head, feature_columns, tree_hparams,
- n_batches_per_layer, config)
+ return _bt_model_fn(features, labels, mode, head, feature_columns,
+ tree_hparams, n_batches_per_layer, config)
super(BoostedTreesRegressor, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns,
+ head=head,
+ center_bias=center_bias,
+ is_classification=False)
+
+
+# pylint: enable=protected-access
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 6e28c72..d4cb3e2 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -564,6 +564,175 @@
self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id)
self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
+ def testTreeComplexityIsSetCorrectly(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ num_steps = 10
+ # Tree complexity is set but no pruning.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ tree_complexity=1e-3)
+ with self.assertRaisesRegexp(ValueError, 'Tree complexity have no effect'):
+ est.train(input_fn, steps=num_steps)
+
+ # Pruning but no tree complexity.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre')
+ with self.assertRaisesRegexp(ValueError,
+ 'tree_complexity must be positive'):
+ est.train(input_fn, steps=num_steps)
+
+ # All is good.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre',
+ tree_complexity=1e-3)
+ est.train(input_fn, steps=num_steps)
+
+
+class BoostedTreesDebugOutputsTest(test_util.TensorFlowTestCase):
+ """Test debug/model explainability outputs for individual predictions.
+
+ Includes directional feature contributions (DFC).
+ """
+
+ def setUp(self):
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+ }
+
+ def testBinaryClassifierThatDFCIsInPredictions(self):
+ train_input_fn = _make_train_input_fn(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=3, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+
+ num_steps = 100
+ # Train for a few steps. Validate debug outputs in prediction dicts.
+ est.train(train_input_fn, steps=num_steps)
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn)
+ biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+ for pred in debug_predictions])
+ self.assertAllClose([0.4] * 5, biases)
+ self.assertAllClose(({
+ 0: -0.12108613453574479,
+ 1: 0.0,
+ 2: -0.039254929814481143
+ }, {
+ 0: 0.19650601422250574,
+ 1: 0.0,
+ 2: 0.02693827052766018
+ }, {
+ 0: 0.16057487356133376,
+ 1: 0.0,
+ 2: 0.02693827052766018
+ }, {
+ 0: -0.12108613453574479,
+ 1: 0.0,
+ 2: -0.039254929814481143
+ }, {
+ 0: -0.10832468554550384,
+ 1: 0.0,
+ 2: 0.02693827052766018
+ }), dfcs)
+
+ # Assert sum(dfcs) + bias == probabilities.
+ expected_probabilities = [
+ 0.23965894, 0.62344426, 0.58751315, 0.23965894, 0.31861359
+ ]
+ probabilities = [
+ sum(dfc.values()) + bias for (dfc, bias) in zip(dfcs, biases)
+ ]
+ self.assertAllClose(expected_probabilities, probabilities)
+
+ # When user doesn't include bias or dfc in predict_keys, make sure to still
+ # include dfc and bias.
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn, predict_keys=['probabilities'])
+ for prediction_dict in debug_predictions:
+ self.assertTrue('bias' in prediction_dict)
+ self.assertTrue('dfc' in prediction_dict)
+ self.assertTrue('probabilities' in prediction_dict)
+ self.assertEqual(len(prediction_dict), 3)
+
+ def testRegressorThatDFCIsInPredictions(self):
+ train_input_fn = _make_train_input_fn(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+
+ num_steps = 100
+ # Train for a few steps. Validate debug outputs in prediction dicts.
+ est.train(train_input_fn, steps=num_steps)
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn)
+ biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+ for pred in debug_predictions])
+ self.assertAllClose([1.8] * 5, biases)
+ self.assertAllClose(({
+ 0: -0.070499420166015625,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: -0.53763031959533691,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: -0.51756942272186279,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: 0.1563495397567749,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: 0.96934974193572998,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }), dfcs)
+
+ # Assert sum(dfcs) + bias == predictions.
+ expected_predictions = [[1.6345005], [1.32570302], [1.1874305],
+ [2.01968288], [2.83268309]]
+ predictions = [
+ [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases)
+ ]
+ self.assertAllClose(expected_predictions, predictions)
+
+ # Test when user doesn't include bias or dfc in predict_keys.
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn, predict_keys=['predictions'])
+ for prediction_dict in debug_predictions:
+ self.assertTrue('bias' in prediction_dict)
+ self.assertTrue('dfc' in prediction_dict)
+ self.assertTrue('predictions' in prediction_dict)
+ self.assertEqual(len(prediction_dict), 3)
+
class ModelFnTests(test_util.TensorFlowTestCase):
"""Tests bt_model_fn including unexposed internal functionalities."""
diff --git a/tensorflow/python/estimator/canned/boosted_trees_utils.py b/tensorflow/python/estimator/canned/boosted_trees_utils.py
new file mode 100644
index 0000000..85efc23
--- /dev/null
+++ b/tensorflow/python/estimator/canned/boosted_trees_utils.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Debug and model explainability logic for boosted trees."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+
+# For directional feature contributions.
+_DEBUG_PROTO_KEY = '_serialized_debug_outputs_proto'
+_BIAS_ID = 0
+
+
+def _parse_debug_proto_string(example_proto_serialized):
+ example_debug_outputs = boosted_trees_pb2.DebugOutput()
+ example_debug_outputs.ParseFromString(example_proto_serialized)
+ feature_ids = example_debug_outputs.feature_ids
+ logits_path = example_debug_outputs.logits_path
+ return feature_ids, logits_path
+
+
+def _compute_directional_feature_contributions(example_feature_ids,
+ example_logits_paths, activation,
+ num_bucketized_features):
+ """Directional feature contributions and bias, per example."""
+ # Initialize contributions to 0.
+ dfcs = {k: 0 for k in range(num_bucketized_features)}
+
+ # Traverse tree subtracting child prediction from parent prediction and
+ # associating change with feature id used to split.
+ predictions = np.array(activation(example_logits_paths))
+ delta_pred = predictions[_BIAS_ID + 1:] - predictions[:-1]
+ # Group by feature id, then sum delta_pred.
+ contribs = np.bincount(
+ example_feature_ids,
+ weights=delta_pred,
+ minlength=num_bucketized_features)
+ for f, dfc in zip(range(num_bucketized_features), contribs):
+ dfcs[f] = dfc
+ return predictions[_BIAS_ID], dfcs
+
+
+def _identity(logits):
+ return logits
+
+
+def _sigmoid(logits):
+ # TODO(crawles): Change to softmax once multiclass support is available.
+ return 1 / (1 + np.exp(-np.array(logits)))
+
+
+def _parse_explanations_from_prediction(serialized_debug_proto,
+ n_features,
+ classification=False):
+ """Parse serialized explanability proto, compute dfc, and return bias, dfc."""
+ feature_ids, logits_path = _parse_debug_proto_string(serialized_debug_proto)
+ if classification:
+ activation = _sigmoid
+ else:
+ activation = _identity
+ bias, dfcs = _compute_directional_feature_contributions(
+ feature_ids, logits_path, activation, n_features)
+ # TODO(crawles): Prediction path and leaf IDs.
+ return bias, dfcs
diff --git a/tensorflow/python/estimator/canned/boosted_trees_utils_test.py b/tensorflow/python/estimator/canned/boosted_trees_utils_test.py
new file mode 100644
index 0000000..506d4ea
--- /dev/null
+++ b/tensorflow/python/estimator/canned/boosted_trees_utils_test.py
@@ -0,0 +1,187 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests boosted_trees estimators and model_fn."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator.canned import boosted_trees_utils
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class BoostedTreesDFCTest(test_util.TensorFlowTestCase):
+ """Test directional feature contributions (DFC) helper functions. """
+
+ def testDirectionalFeatureContributionsCompute(self):
+ """Tests logic to compute DFCs given feature ids and logits paths."""
+ num_bucketized_features = 3 # Includes one unused feature.
+ examples_feature_ids = ((2, 2, 0, 0), (2, 2, 0))
+ e1_feature_ids, e2_feature_ids = examples_feature_ids
+
+ # DFCs are computed by traversing the prediction path and subtracting each
+ # child prediction from its parent prediction and associating the change in
+ # prediction with the respective feature id used for the split.
+ # For each activation function, f, (currently identity or sigmoid), DFCs are
+ # calculated for the two examples as:
+ # example 1:
+ # feature_0 = (f(1.114) - f(1.214)) + (f(6.114) - f(1.114))
+ # feature_1 = 0 # Feature not in ensemble, thus zero contrib.
+ # feature_2 = (f(0.114) - bias_pred) + (f(1.214) - f(0.114))
+ # example 2:
+ # feature_0 = f(-5.486) - f(1.514)
+ # feature_1 = 0 # Feature not in ensemble, thus zero contrib.
+ # feature_2 = (f(0.114) - bias_pred) + (f(1.514) - f(0.114))
+ # where bias_pred is = f(0) or f(0.21), with center_bias = {True, False},
+ # respectively.
+ # Keys are center_bias.
+ expected_dfcs_identity = {
+ False: ({
+ 0: 4.9,
+ 1: 0,
+ 2: 1.214
+ }, {
+ 0: -7.0,
+ 1: 0,
+ 2: 1.514
+ }),
+ True: ({
+ 0: 4.9,
+ 1: 0,
+ 2: 1.0039999999999998
+ }, {
+ 0: -7.0,
+ 1: 0,
+ 2: 1.3039999999999998
+ })
+ }
+ expected_dfcs_sigmoid = {
+ False: ({
+ 0: 0.22678725678805578,
+ 1: 0,
+ 2: 0.2710059376234506
+ }, {
+ 0: -0.81552596670046507,
+ 1: 0,
+ 2: 0.319653250251275
+ }),
+ True: ({
+ 0: 0.22678725678805578,
+ 1: 0,
+ 2: 0.2186980280491253
+ }, {
+ 0: -0.81552596670046507,
+ 1: 0,
+ 2: 0.26734534067694971
+ })
+ }
+ # pylint: disable=protected-access
+ for f, expected_dfcs in zip(
+ (boosted_trees_utils._identity, boosted_trees_utils._sigmoid),
+ (expected_dfcs_identity, expected_dfcs_sigmoid)):
+ for center_bias in [False, True]:
+ # If not center_bias, the bias after activation is 0.
+ if center_bias:
+ bias_logit = 0.21 # Root node of tree_0.
+ else:
+ bias_logit = 0 # 0 is default value when there is no original_leaf.
+ f_bias = f(bias_logit)
+
+ # Logits before and after, as is outputed from
+ # boosted_trees_ops.example_debug_outputs
+ examples_logits_paths = ((bias_logit, 0.114, 1.214, 1.114, 6.114),
+ (bias_logit, 0.114, 1.514, -5.486))
+ e1_logits_path, e2_logits_path = examples_logits_paths
+ e1_expected_dfcs, e2_expected_dfcs = expected_dfcs[center_bias]
+ # Check feature contributions are correct for both examples.
+ # Example 1.
+ # pylint:disable=line-too-long
+ e1_bias, e1_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e1_feature_ids, e1_logits_path, f, num_bucketized_features)
+ self.assertAllClose(e1_bias, f_bias)
+ self.assertAllClose(e1_dfc, e1_expected_dfcs)
+ # Example 2.
+ e2_bias, e2_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e2_feature_ids, e2_logits_path, f, num_bucketized_features)
+ # pylint:enable=line-too-long
+ self.assertAllClose(e2_bias, f_bias)
+ self.assertAllClose(e2_dfc, e2_expected_dfcs)
+ # Check if contributions sum to final prediction.
+ # For each tree, get leaf of last tree.
+ expected_logits = (e1_logits_path[-1], e2_logits_path[-1])
+ # Predictions should be the sum of contributions + bias.
+ expected_preds = [f(logit) for logit in expected_logits]
+ e1_pred = e1_bias + sum(e1_dfc.values())
+ e2_pred = e2_bias + sum(e2_dfc.values())
+ preds = [e1_pred, e2_pred]
+ self.assertAllClose(preds, expected_preds)
+ # pylint: enable=protected-access
+
+ def testDFCComputeComparedToExternalExample(self):
+ """Tests `compute_dfc` compared to external example (regression).
+
+ Example from http://blog.datadive.net/interpreting-random-forests.
+ """
+ # DIS:3, RM: 2, LSTAT:1, NOX:0
+ num_bucketized_features = 4
+ e1_feature_ids = (2, 1, 0)
+ e2_feature_ids = (2, 2, 2)
+ e3_feature_ids = (2, 2, 0)
+
+ bias_logit = 22.60 # Root node of tree_0.
+ activation = boosted_trees_utils._identity
+ f_bias = activation(bias_logit)
+ # Logits before and after, as is outputed from
+ # boosted_trees_ops.example_debug_outputs
+ e1_logits_path = (bias_logit, 19.96, 14.91, 18.11)
+ e2_logits_path = (bias_logit, 37.42, 45.10, 45.90)
+ e3_logits_path = (bias_logit, 37.42, 32.30, 33.58)
+ e1_expected_dfcs = {0: 3.20, 1: -5.05, 2: -2.64, 3: 0}
+ e2_expected_dfcs = {0: 0, 1: 0, 2: 23.3, 3: 0}
+ e3_expected_dfcs = {0: 1.28, 1: 0, 2: 9.7, 3: 0}
+ # Check feature contributions are correct for both examples.
+ # Example 1.
+ # pylint: disable=protected-access
+ # pylint: disable=line-too-long
+ e1_bias, e1_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e1_feature_ids, e1_logits_path, activation, num_bucketized_features)
+ self.assertAllClose(e1_bias, f_bias)
+ self.assertAllClose(e1_dfc, e1_expected_dfcs)
+ # Example 2.
+ e2_bias, e2_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e2_feature_ids, e2_logits_path, activation, num_bucketized_features)
+ self.assertAllClose(e2_bias, f_bias)
+ self.assertAllClose(e2_dfc, e2_expected_dfcs)
+ # Example 3.
+ e3_bias, e3_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e3_feature_ids, e3_logits_path, activation, num_bucketized_features)
+ # pylint: enable=line-too-long
+ self.assertAllClose(e3_bias, f_bias)
+ self.assertAllClose(e3_dfc, e3_expected_dfcs)
+ # pylint: enable=protected-access
+ # Check if contributions sum to final prediction.
+ # For each tree, get leaf of last tree.
+ expected_logits = (18.11, 45.90, 33.58)
+ # Predictions should be the sum of contributions + bias.
+ expected_preds = [activation(logit) for logit in expected_logits]
+ e1_pred = e1_bias + sum(e1_dfc.values())
+ e2_pred = e2_bias + sum(e2_dfc.values())
+ e3_pred = e3_bias + sum(e3_dfc.values())
+ preds = [e1_pred, e2_pred, e3_pred]
+ self.assertAllClose(preds, expected_preds)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index 3eed1ab..ed3219c 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -376,7 +376,7 @@
" } "
"} ", example)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_result = sess.run(
serving_input_receiver.features,
feed_dict={
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index a8aef3a..68b3170 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -762,13 +762,12 @@
if handle_data:
handle_data = handle_data.SerializeToString()
else:
- handle_data = c_api.GetResourceHandleShapeAndType(
- tensor.graph._c_graph, tensor._as_tf_output())
+ handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph,
+ tensor._as_tf_output())
if handle_data:
- c_api.SetResourceHandleShapeAndType(ph.graph._c_graph,
- ph._as_tf_output(),
- compat.as_bytes(handle_data))
+ c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(),
+ compat.as_bytes(handle_data))
else:
ph._handle_data = tensor._handle_data
# pylint: enable=protected-access
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 343f52f..8bb1779 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2532,8 +2532,8 @@
output._shape_val = output._c_api_shape()
# Set the resource handle data for compatibility with the Python shape
# inference code.
- serialized = c_api.GetResourceHandleShapeAndType(op._graph._c_graph,
- output._as_tf_output())
+ serialized = c_api.GetHandleShapeAndType(op._graph._c_graph, # pylint: disable=protected-access
+ output._as_tf_output())
if serialized:
output._handle_data = (
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index b739823..c302072 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -24,6 +24,7 @@
import contextlib
import gc
import itertools
+import os
import math
import random
import re
@@ -868,6 +869,19 @@
yield
+class CapturedWrites(object):
+ """A utility class to load the captured writes made to a stream."""
+
+ def __init__(self, capture_location):
+ self.capture_location = capture_location
+
+ def contents(self):
+ """Get the captured writes as a single string."""
+ with open(self.capture_location) as tmp_file:
+ output_data = "".join(tmp_file.readlines())
+ return output_data
+
+
class ErrorLoggingSession(session.Session):
"""Wrapper around a Session that logs errors in run().
"""
@@ -934,6 +948,52 @@
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
+ @contextlib.contextmanager
+ def captureWritesToStream(self, stream):
+ """A context manager that captures the writes to a given stream.
+
+ This context manager captures all writes to a given stream inside of a
+ `CapturedWrites` object. When this context manager is created, it yields
+ the `CapturedWrites` object. The captured contents can be accessed by
+ calling `.contents()` on the `CapturedWrites`.
+
+ For this function to work, the stream must have a file descriptor that
+ can be modified using `os.dup` and `os.dup2`, and the stream must support
+ a `.flush()` method. The default python sys.stdout and sys.stderr are
+ examples of this. Note that this does not work in Colab or Jupyter
+ notebooks, because those use alternate stdout streams.
+
+ Example:
+ ```python
+ class MyOperatorTest(test_util.TensorFlowTestCase):
+ def testMyOperator(self):
+ input = [1.0, 2.0, 3.0, 4.0, 5.0]
+ with self.captureWritesToStream(sys.stdout) as captured:
+ result = MyOperator(input).eval()
+ self.assertStartsWith(captured.contents(), "This was printed.")
+ ```
+
+ Args:
+ stream: The stream whose writes should be captured. This
+ stream must have a file descriptor, support writing via using that
+ file descriptor, and must have a `.flush()` method.
+
+ Yields:
+ A `CapturedWrites` object that contains all writes to the specified stream
+ made during this context.
+ """
+ stream.flush()
+ fd = stream.fileno()
+ tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir())
+ tmp_file = open(tmp_file_path, "w")
+ orig_fd = os.dup(fd)
+ os.dup2(tmp_file.fileno(), fd)
+ try:
+ yield CapturedWrites(tmp_file_path)
+ finally:
+ tmp_file.close()
+ os.dup2(orig_fd, fd)
+
def _AssertProtoEquals(self, a, b, msg=None):
"""Asserts that a and b are the same proto.
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index b521b14..4a72c4b 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -381,12 +381,11 @@
],
)
-py_test(
+cuda_py_test(
name = "embeddings_test",
size = "medium",
srcs = ["layers/embeddings_test.py"],
- srcs_version = "PY2AND3",
- deps = [
+ additional_deps = [
":keras",
"//tensorflow/python:client_testlib",
],
diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py
index a8b6d55..c35cdb1 100644
--- a/tensorflow/python/keras/applications/__init__.py
+++ b/tensorflow/python/keras/applications/__init__.py
@@ -63,7 +63,8 @@
def wrapper(*args, **kwargs):
if hasattr(keras_applications, 'get_submodules_from_kwargs'):
kwargs['backend'] = backend
- kwargs['layers'] = layers
+ if 'layers' not in kwargs:
+ kwargs['layers'] = layers
kwargs['models'] = models
kwargs['utils'] = utils
return base_fun(*args, **kwargs)
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index befe82f..6dfbbf3 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -360,7 +360,10 @@
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
- self.seen += batch_size
+ # In case of distribution strategy we can potentially run multiple steps
+ # at the same time, we should account for that in the `seen` calculation.
+ num_steps = logs.get('num_steps', 1)
+ self.seen += batch_size * num_steps
for k, v in logs.items():
if k in self.stateful_metrics:
@@ -448,10 +451,13 @@
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
+ # In case of distribution strategy we can potentially run multiple steps
+ # at the same time, we should account for that in the `seen` calculation.
+ num_steps = logs.get('num_steps', 1)
if self.use_steps:
- self.seen += 1
+ self.seen += num_steps
else:
- self.seen += batch_size
+ self.seen += batch_size * num_steps
for k in self.params['metrics']:
if k in logs:
@@ -1068,7 +1074,7 @@
logs = logs or {}
batch_logs = {('batch_' + k): v
for k, v in logs.items()
- if k not in ['batch', 'size']}
+ if k not in ['batch', 'size', 'num_steps']}
self._write_custom_summaries(self._total_batches_seen, batch_logs)
self._total_batches_seen += 1
@@ -1092,7 +1098,7 @@
# batch number as Tensorboard summaries
logs = {('epoch_' + k): v
for k, v in logs.items()
- if k not in ['batch', 'size']}
+ if k not in ['batch', 'size', 'num_steps']}
self._write_custom_summaries(epoch, logs)
# pop the histogram summary op after each epoch
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index dc464c0..154c219 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -383,27 +383,31 @@
"""
# Validate that arguments passed by the user to `compile` are supported by
# DistributionStrategy.
- if distribute and not isinstance(
- optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
- raise NotImplementedError('Only TF native optimizers are supported with '
- 'DistributionStrategy.')
- if distribute and context.executing_eagerly():
- raise NotImplementedError('DistributionStrategy is not supported in '
- 'Eager mode.')
- if distribute and sample_weight_mode:
- raise NotImplementedError('sample_weight_mode is not supported with '
- 'DistributionStrategy.')
- if distribute and weighted_metrics:
- raise NotImplementedError('weighted_metrics is not supported with '
- 'DistributionStrategy.')
- if distribute and target_tensors:
- raise ValueError('target_tensors is not supported with '
- 'DistributionStrategy.')
+ if distribute:
+ if not isinstance(
+ optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
+ raise NotImplementedError(
+ 'optimizer must be an instance of '
+ 'tf.train.Optimizer, not a %s' % type(optimizer))
+ if context.executing_eagerly():
+ raise NotImplementedError('DistributionStrategy is not supported '
+ 'when eager execution is enabled.')
+ if sample_weight_mode:
+ raise NotImplementedError('sample_weight_mode is not supported with '
+ 'DistributionStrategy.')
+ if weighted_metrics:
+ raise NotImplementedError('weighted_metrics is not supported with '
+ 'DistributionStrategy.')
+ if target_tensors:
+ raise ValueError('target_tensors is not supported with '
+ 'DistributionStrategy.')
loss = loss or {}
if context.executing_eagerly() and not isinstance(
optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
- raise ValueError('Only TF native optimizers are supported in Eager mode.')
+ raise ValueError(
+ 'optimizer must be an instance of tf.train.Optimizer, not '
+ 'a %s' % type(optimizer))
self.optimizer = optimizers.get(optimizer)
# We've disabled automatic dependency tracking for this method, but do want
@@ -422,8 +426,9 @@
# Set DistributionStrategy specific parameters.
self._distribution_strategy = distribute
+ # Reset the value of grouped_model
+ self._grouped_model = None
if self._distribution_strategy is not None:
- self._grouped_model = None
distributed_training_utils.configure_and_create_session(
self._distribution_strategy)
if not self.built:
@@ -445,7 +450,8 @@
for name in self.output_names:
if name not in loss:
logging.warning(
- 'Output "' + name + '" missing from loss dictionary. We assume '
+ 'Output "' + name +
+ '" missing from loss dictionary. We assume '
'this was done on purpose. The fit and evaluate APIs will not be '
'expecting any data to be passed to "' + name + '".')
loss_functions.append(losses.get(loss.get(name)))
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 53291c3..26c5ec4 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -20,6 +20,7 @@
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
@@ -292,11 +293,16 @@
for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+ if steps_per_epoch is None:
+ raise ValueError('steps_per_epoch should be specified in the fit call.')
+ steps_per_run_var = K.variable(
+ value=min(steps_per_epoch, current_strategy.steps_per_run),
+ dtype='int32',
+ name='steps_per_run_var')
+
with current_strategy.scope():
- # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
- # steps_per_epoch and number of epochs.
ctx = current_strategy.run_steps_on_dataset(
- step_fn, iterator, iterations=current_strategy.steps_per_run,
+ step_fn, iterator, iterations=steps_per_run_var,
initial_loop_values=initial_loop_values)
train_op = ctx.run_op
@@ -308,14 +314,6 @@
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
-
- assert steps_per_epoch is not None
-
- # TODO(sourabhbajaj): Convert this into a proper validation function
- if callbacks:
- raise NotImplementedError(
- 'Callbacks are not supported with TPUStrategy right now.')
-
callbacks = cbks.configure_callbacks(
callbacks,
model,
@@ -326,17 +324,26 @@
steps_per_epoch=steps_per_epoch,
verbose=verbose)
# TODO(priyag, sourabhbajaj): Add callbacks support for per step callback
- # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run
# TODO(priyag, sourabhbajaj): Add validation.
+
+ # Calculate the steps each time on the device.
+ steps_to_run = [current_strategy.steps_per_run] * (
+ steps_per_epoch // current_strategy.steps_per_run)
+ if steps_per_epoch % current_strategy.steps_per_run:
+ steps_to_run.append(steps_per_epoch % current_strategy.steps_per_run)
+
callbacks.on_train_begin()
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
- for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run):
- # TODO(sourabhbajaj): Replace size with a combination of steps_per_run
- # and batch_size
- batch_logs = {'batch': step_index, 'size': 1}
+ step_index = 0
+ prev_step_count = None
+ for step_count in steps_to_run:
+ batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
callbacks.on_batch_begin(step_index, batch_logs)
+ if prev_step_count is None or step_count != prev_step_count:
+ steps_per_run_var.load(step_count, K.get_session())
+ prev_step_count = step_count
try:
_, outputs = K.get_session().run([train_op, output_tensors])
except errors.OutOfRangeError:
@@ -349,6 +356,7 @@
batch_logs.update(outputs)
callbacks.on_batch_end(step_index, batch_logs)
+ step_index = step_index + step_count
if callbacks.model.stop_training:
break
@@ -742,8 +750,9 @@
for name, tensor in zip(model.output_names, model.outputs):
# TODO(priyag): This is a workaround as we do not know the batch dimension
# of the model's output at this point.
- tensor.shape.dims = [batch_dimension] + tensor.shape.dims[1:]
- initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+ shape = tensor_shape.TensorShape(tensor.shape.dims)
+ shape.dims = [batch_dimension] + shape.dims[1:]
+ initial_loop_values[name] = array_ops.zeros(shape, tensor.dtype)
with current_strategy.scope():
# TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed.
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index 939a7f2..fb71bf2 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -186,7 +186,7 @@
# make sure either x,y or x,y,sample_weights is provided
if (not isinstance(inputs.output_shapes, (list, tuple)) or
len(inputs.output_shapes) not in (2, 3)):
- raise ValueError('Please provide either inputs and targets'
+ raise ValueError('Please provide either inputs and targets '
'or inputs, targets, and sample_weights')
for step_index in range(steps_per_epoch):
diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py
index 4ab786a..a2385df 100644
--- a/tensorflow/python/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/layers/advanced_activations.py
@@ -314,7 +314,9 @@
'cannot be negative value: ' + str(negative_slope))
self.support_masking = True
- self.max_value = K.cast_to_floatx(max_value)
+ if max_value is not None:
+ max_value = K.cast_to_floatx(max_value)
+ self.max_value = max_value
self.negative_slope = K.cast_to_floatx(negative_slope)
self.threshold = K.cast_to_floatx(threshold)
diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py
index b020b6e..c41087b 100644
--- a/tensorflow/python/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/layers/advanced_activations_test.py
@@ -67,6 +67,14 @@
testing_utils.layer_test(keras.layers.ReLU,
kwargs={'max_value': 10},
input_shape=(2, 3, 4))
+ x = keras.backend.ones((3, 4))
+ # Test that we use `leaky_relu` when appropriate in graph mode.
+ self.assertTrue(
+ 'LeakyRelu' in keras.layers.ReLU(negative_slope=0.2)(x).name)
+ # Test that we use `relu` when appropriate in graph mode.
+ self.assertTrue('Relu' in keras.layers.ReLU()(x).name)
+ # Test that we use `relu6` when appropriate in graph mode.
+ self.assertTrue('Relu6' in keras.layers.ReLU(max_value=6)(x).name)
def test_relu_with_invalid_arg(self):
with self.assertRaisesRegexp(
diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py
index 629a9ec..c6df5f2 100644
--- a/tensorflow/python/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/layers/embeddings.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
@@ -117,12 +119,27 @@
@tf_utils.shape_type_conversion
def build(self, input_shape):
- self.embeddings = self.add_weight(
- shape=(self.input_dim, self.output_dim),
- initializer=self.embeddings_initializer,
- name='embeddings',
- regularizer=self.embeddings_regularizer,
- constraint=self.embeddings_constraint)
+ # Note: most sparse optimizers do not have GPU kernels defined. When
+ # building graphs, the placement algorithm is able to place variables on CPU
+ # since it knows all kernels using the variable only exist on CPU.
+ # When eager execution is enabled, the placement decision has to be made
+ # right now. Checking for the presence of GPUs to avoid complicating the
+ # TPU codepaths which can handle sparse optimizers.
+ if context.executing_eagerly() and context.context().num_gpus():
+ with ops.device('cpu:0'):
+ self.embeddings = self.add_weight(
+ shape=(self.input_dim, self.output_dim),
+ initializer=self.embeddings_initializer,
+ name='embeddings',
+ regularizer=self.embeddings_regularizer,
+ constraint=self.embeddings_constraint)
+ else:
+ self.embeddings = self.add_weight(
+ shape=(self.input_dim, self.output_dim),
+ initializer=self.embeddings_initializer,
+ name='embeddings',
+ regularizer=self.embeddings_regularizer,
+ constraint=self.embeddings_constraint)
self.built = True
def compute_mask(self, inputs, mask=None):
diff --git a/tensorflow/python/keras/layers/embeddings_test.py b/tensorflow/python/keras/layers/embeddings_test.py
index cab176e..2e42e40 100644
--- a/tensorflow/python/keras/layers/embeddings_test.py
+++ b/tensorflow/python/keras/layers/embeddings_test.py
@@ -21,9 +21,11 @@
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
+from tensorflow.python.training import adagrad
class EmbeddingTest(test.TestCase):
@@ -78,6 +80,17 @@
outputs = keras.backend.eval(layer(inputs))
self.assertAllClose(outputs, [[[1, 1], [2, 2], [1, 1]]])
+ @tf_test_util.run_in_graph_and_eager_modes()
+ def test_eager_gpu_cpu(self):
+ l = keras.layers.Embedding(output_dim=2, input_dim=2)
+ l.build((None, 2))
+ inputs = keras.backend.constant([[0, 1, 0]], dtype='int32')
+ with backprop.GradientTape() as tape:
+ output = l(inputs)
+ gs = tape.gradient(output, l.weights)
+ opt = adagrad.AdagradOptimizer(0.1)
+ opt.apply_gradients(zip(gs, l.weights))
+ self.assertAllEqual(len(gs), 1)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index c7e9499..d6016ed 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -48,7 +48,7 @@
if not check_if_compatible_devices(gpus=gpus):
return
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(hidden_dim,
input_shape=(input_dim,)))
@@ -78,7 +78,7 @@
if not check_if_compatible_devices(gpus=gpus):
return
- with self.test_session():
+ with self.cached_session():
input_a = keras.Input((input_dim_a,))
input_b = keras.Input((input_dim_b,))
a = keras.layers.Dense(hidden_dim)(input_a)
@@ -105,7 +105,7 @@
if not check_if_compatible_devices(gpus=2):
return
- with self.test_session():
+ with self.cached_session():
input_shape = (1000, 10)
model = keras.models.Sequential()
model.add(keras.layers.Dense(10,
@@ -144,7 +144,7 @@
if not check_if_compatible_devices(gpus=gpus):
return
- with self.test_session():
+ with self.cached_session():
input_shape = (num_samples,) + shape
x_train = np.random.randint(0, 255, input_shape)
y_train = np.random.randint(0, num_classes, (input_shape[0],))
diff --git a/tensorflow/python/keras/wrappers/scikit_learn_test.py b/tensorflow/python/keras/wrappers/scikit_learn_test.py
index c322efd..f904290 100644
--- a/tensorflow/python/keras/wrappers/scikit_learn_test.py
+++ b/tensorflow/python/keras/wrappers/scikit_learn_test.py
@@ -102,7 +102,7 @@
class ScikitLearnAPIWrapperTest(test.TestCase):
def test_classify_build_fn(self):
- with self.test_session():
+ with self.cached_session():
clf = keras.wrappers.scikit_learn.KerasClassifier(
build_fn=build_fn_clf,
hidden_dim=HIDDEN_DIM,
@@ -118,7 +118,7 @@
def __call__(self, hidden_dim):
return build_fn_clf(hidden_dim)
- with self.test_session():
+ with self.cached_session():
clf = keras.wrappers.scikit_learn.KerasClassifier(
build_fn=ClassBuildFnClf(),
hidden_dim=HIDDEN_DIM,
@@ -134,7 +134,7 @@
def __call__(self, hidden_dim):
return build_fn_clf(hidden_dim)
- with self.test_session():
+ with self.cached_session():
clf = InheritClassBuildFnClf(
build_fn=None,
hidden_dim=HIDDEN_DIM,
@@ -144,7 +144,7 @@
assert_classification_works(clf)
def test_regression_build_fn(self):
- with self.test_session():
+ with self.cached_session():
reg = keras.wrappers.scikit_learn.KerasRegressor(
build_fn=build_fn_reg,
hidden_dim=HIDDEN_DIM,
@@ -160,7 +160,7 @@
def __call__(self, hidden_dim):
return build_fn_reg(hidden_dim)
- with self.test_session():
+ with self.cached_session():
reg = keras.wrappers.scikit_learn.KerasRegressor(
build_fn=ClassBuildFnReg(),
hidden_dim=HIDDEN_DIM,
@@ -176,7 +176,7 @@
def __call__(self, hidden_dim):
return build_fn_reg(hidden_dim)
- with self.test_session():
+ with self.cached_session():
reg = InheritClassBuildFnReg(
build_fn=None,
hidden_dim=HIDDEN_DIM,
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 100240a..17831fa 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -538,6 +538,21 @@
)
tf_py_test(
+ name = "logging_ops_logging_level_test",
+ size = "small",
+ srcs = ["logging_ops_logging_level_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:logging_ops",
+ ],
+ tags = [
+ "no_windows",
+ ],
+)
+
+tf_py_test(
name = "logging_ops_test",
size = "small",
srcs = ["logging_ops_test.py"],
@@ -961,6 +976,19 @@
)
tf_py_test(
+ name = "string_format_op_test",
+ size = "small",
+ srcs = ["string_format_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:math_ops",
+ ],
+)
+
+tf_py_test(
name = "string_join_op_test",
size = "small",
srcs = ["string_join_op_test.py"],
@@ -3204,3 +3232,27 @@
grpc_enabled = True,
tags = ["no_gpu"], # TODO(b/111656070)
)
+
+# TODO(b/116053459): Replace with cuda_py_test.
+tf_py_test(
+ name = "while_v2_test",
+ size = "medium",
+ srcs = ["while_v2_test.py"],
+ additional_deps = [
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients_impl",
+ "//tensorflow/python:list_ops",
+ "//tensorflow/python:tf_optimizer",
+ "//tensorflow/python:while_v2",
+ ],
+ grpc_enabled = True,
+ tags = ["no_gpu"], # TODO(b/116053459)
+)
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 573bb86..2fe8583 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1276,5 +1276,203 @@
self.assertAllEqual(y.eval(), [0, 1, 2, 3])
+@test_util.run_all_in_graph_and_eager_modes
+class SortedSearchTest(test_util.TensorFlowTestCase):
+
+ def testUpperBoundFloatHandCoded(self):
+ cdf = np.array([0, .2, .5, .6, .8, 1.], dtype=np.float32)
+ arr = np.array([.04, .99, .53, .58, .31, .01, .79, .8, .21],
+ dtype=np.float32)
+ result = np.searchsorted(cdf, arr, side="right")
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundFloatRandomNd(self):
+ dim_size = 7
+ for d in range(1, 5):
+ shape = [dim_size] * d
+ cdf = np.cumsum(
+ np.random.uniform(size=shape).astype(np.float32), axis=(d - 1))
+ arr = np.random.uniform(size=shape).astype(np.float32) * dim_size
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+ cdf = cdf.reshape([-1, dim_size])
+ arr = arr.reshape([-1, dim_size])
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(dim_size**(d - 1)):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+ result = result.reshape(shape)
+
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundFloatUneven(self):
+ batch_size = 7
+ size_search_array = 1000
+ size_values = 47
+ cdf = np.cumsum(
+ np.random.uniform(size=[batch_size, size_search_array]).astype(
+ np.float32),
+ axis=1)
+ arr = np.random.uniform(size=[batch_size, size_values]).astype(
+ np.float32) * size_search_array
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(batch_size):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundFloatHandCoded(self):
+ cdf = np.array([0, .2, .5, .6, .8, 1.], dtype=np.float32)
+ arr = np.array([.04, .99, .53, .58, .31, .01, .79, .8, .21],
+ dtype=np.float32)
+ result = np.searchsorted(cdf, arr, side="left")
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundFloatRandomNd(self):
+ dim_size = 7
+ for d in range(1, 5):
+ shape = [dim_size] * d
+ cdf = np.cumsum(
+ np.random.uniform(size=shape).astype(np.float32), axis=(d - 1))
+ arr = np.random.uniform(size=shape).astype(np.float32) * dim_size
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+ cdf = cdf.reshape([-1, dim_size])
+ arr = arr.reshape([-1, dim_size])
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(dim_size**(d - 1)):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+ result = result.reshape(shape)
+
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundFloatUneven(self):
+ batch_size = 7
+ size_search_array = 1000
+ size_values = 47
+ cdf = np.cumsum(
+ np.random.uniform(size=[batch_size, size_search_array]).astype(
+ np.float32),
+ axis=1)
+ arr = np.random.uniform(size=[batch_size, size_values]).astype(
+ np.float32) * size_search_array
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(batch_size):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundIntHandCoded(self):
+ cdf = np.array([0, 20, 50, 60, 80, 100], dtype=np.int64)
+ arr = np.array([4, 99, 53, 58, 31, 1, 79, 8, 21], dtype=np.int64)
+ result = np.searchsorted(cdf, arr, side="right")
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundIntRandomNd(self):
+ dim_size = 7
+ for d in range(1, 5):
+ shape = [dim_size] * d
+ cdf = np.cumsum(
+ np.random.randint(low=0, high=10, size=shape).astype(np.int64),
+ axis=(d - 1))
+ arr = np.random.randint(
+ low=0, high=10 * dim_size, size=shape).astype(np.int64)
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+ cdf = cdf.reshape([-1, dim_size])
+ arr = arr.reshape([-1, dim_size])
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(dim_size**(d - 1)):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+ result = result.reshape(shape)
+
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundIntUneven(self):
+ batch_size = 7
+ size_search_array = 1000
+ size_values = 47
+ cdf = np.cumsum(
+ np.random.randint(low=0, high=10,
+ size=[batch_size,
+ size_search_array]).astype(np.int64),
+ axis=1)
+ arr = np.random.randint(
+ low=0, high=10 * size_search_array, size=[batch_size,
+ size_values]).astype(np.int64)
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(batch_size):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundIntHandCoded(self):
+ cdf = np.array([0, 20, 50, 60, 80, 100], dtype=np.int64)
+ arr = np.array([4, 99, 53, 58, 31, 1, 79, 8, 21], dtype=np.int64)
+ result = np.searchsorted(cdf, arr, side="left")
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundIntRandomNd(self):
+ dim_size = 7
+ for d in range(1, 5):
+ shape = [dim_size] * d
+ cdf = np.cumsum(
+ np.random.randint(low=0, high=10, size=shape).astype(np.int64),
+ axis=(d - 1))
+ arr = np.random.randint(
+ low=0, high=10 * dim_size, size=shape).astype(np.int64)
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+ cdf = cdf.reshape([-1, dim_size])
+ arr = arr.reshape([-1, dim_size])
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(dim_size**(d - 1)):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+ result = result.reshape(shape)
+
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundIntUneven(self):
+ batch_size = 7
+ size_search_array = 1000
+ size_values = 47
+ cdf = np.cumsum(
+ np.random.randint(low=0, high=10,
+ size=[batch_size,
+ size_search_array]).astype(np.int64),
+ axis=1)
+ arr = np.random.randint(
+ low=0, high=10 * size_search_array, size=[batch_size,
+ size_values]).astype(np.int64)
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(batch_size):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+ self.assertAllEqual(result, tf_result)
+
+
if __name__ == "__main__":
test_lib.main()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
index dee9610..3b28d44 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -928,6 +928,163 @@
class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
"""Tests feature contribs ops for model understanding."""
+ def testContribsForOnlyABiasNode(self):
+ """Tests case when, after training, only left with a bias node.
+
+ For example, this could happen if the final ensemble contains one tree that
+ got pruned up to the root.
+ """
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ leaf {
+ scalar: 1.72
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata: {
+ num_layers_grown: 0
+ }
+ """, tree_ensemble_config)
+
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # All features are unused.
+ feature_0_values = [36, 32]
+ feature_1_values = [13, -29]
+ feature_2_values = [11, 27]
+
+ # Expected logits are computed by traversing the logit path and
+ # subtracting child logits from parent logits.
+ bias = 1.72 * 0.1 # Root node of tree_0.
+ expected_feature_ids = ((), ())
+ expected_logits_paths = ((bias,), (bias,))
+
+ bucketized_features = [
+ feature_0_values, feature_1_values, feature_2_values
+ ]
+
+ debug_op = boosted_trees_ops.example_debug_outputs(
+ tree_ensemble_handle,
+ bucketized_features=bucketized_features,
+ logits_dimension=1)
+
+ serialized_examples_debug_outputs = session.run(debug_op)
+ feature_ids = []
+ logits_paths = []
+ for example in serialized_examples_debug_outputs:
+ example_debug_outputs = boosted_trees_pb2.DebugOutput()
+ example_debug_outputs.ParseFromString(example)
+ feature_ids.append(example_debug_outputs.feature_ids)
+ logits_paths.append(example_debug_outputs.logits_path)
+
+ self.assertAllClose(feature_ids, expected_feature_ids)
+ self.assertAllClose(logits_paths, expected_logits_paths)
+
+ def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
+ """Tests case when, after training, first tree contains only a bias node."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ leaf {
+ scalar: 1.72
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 26
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 50
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ original_leaf: {scalar: 5.5}
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ tree_weights: 1.
+ tree_weights: 0.1
+ tree_metadata: {
+ num_layers_grown: 0
+ }
+ tree_metadata: {
+ num_layers_grown: 1
+ }
+ """, tree_ensemble_config)
+
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [36, 32]
+ feature_1_values = [13, -29] # Unused feature.
+ feature_2_values = [11, 27]
+
+ # Expected logits are computed by traversing the logit path and
+ # subtracting child logits from parent logits.
+ expected_feature_ids = ((2, 0), (2,))
+ # bias = 1.72 * 1. # Root node of tree_0.
+ # example_0 : (bias, 0.1 * 5.5 + bias, 0.1 * 5. + bias)
+ # example_1 : (bias, 0.1 * 7. + bias )
+ expected_logits_paths = ((1.72, 2.27, 2.22), (1.72, 2.42))
+
+ bucketized_features = [
+ feature_0_values, feature_1_values, feature_2_values
+ ]
+
+ debug_op = boosted_trees_ops.example_debug_outputs(
+ tree_ensemble_handle,
+ bucketized_features=bucketized_features,
+ logits_dimension=1)
+
+ serialized_examples_debug_outputs = session.run(debug_op)
+ feature_ids = []
+ logits_paths = []
+ for example in serialized_examples_debug_outputs:
+ example_debug_outputs = boosted_trees_pb2.DebugOutput()
+ example_debug_outputs.ParseFromString(example)
+ feature_ids.append(example_debug_outputs.feature_ids)
+ logits_paths.append(example_debug_outputs.logits_path)
+
+ self.assertAllClose(feature_ids, expected_feature_ids)
+ self.assertAllClose(logits_paths, expected_logits_paths)
+
def testContribsMultipleTree(self):
"""Tests that the contribs work when we have multiple trees."""
with self.cached_session() as session:
@@ -1018,11 +1175,14 @@
tree_weights: 0.2
tree_weights: 1.0
tree_metadata: {
- num_layers_grown: 1}
+ num_layers_grown: 1
+ }
tree_metadata: {
- num_layers_grown: 2}
+ num_layers_grown: 2
+ }
tree_metadata: {
- num_layers_grown: 1}
+ num_layers_grown: 1
+ }
""", tree_ensemble_config)
tree_ensemble = boosted_trees_ops.TreeEnsemble(
diff --git a/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py b/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py
new file mode 100644
index 0000000..252090b
--- /dev/null
+++ b/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.kernels.logging_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+class PrintV2LoggingLevelTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogInfo(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.info)
+ self.evaluate(print_op)
+ self.assertTrue("I" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogWarning(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.warning)
+ self.evaluate(print_op)
+ self.assertTrue("W" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogError(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.error)
+ self.evaluate(print_op)
+ self.assertTrue("E" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index 82729b9..cf0beba 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -18,14 +18,23 @@
from __future__ import division
from __future__ import print_function
+import sys
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
class LoggingOpsTest(test.TestCase):
@@ -57,6 +66,302 @@
out.eval()
+class PrintV2Test(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensor(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorVarySummarize(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=1)
+ self.evaluate(print_op)
+
+ expected = "[0 ... 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=2)
+ self.evaluate(print_op)
+
+ expected = "[0 1 ... 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=3)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=-1)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 3 4 5 6 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneVariable(self):
+ with self.test_session():
+ var = variables.Variable(math_ops.range(10))
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(var)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintTwoVariablesInStructWithAssignAdd(self):
+ with self.test_session():
+ var_one = variables.Variable(2.14)
+ plus_one = var_one.assign_add(1.0)
+ var_two = variables.Variable(math_ops.range(10))
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ with self.captureWritesToStream(sys.stderr) as printed:
+ self.evaluate(plus_one)
+ print_op = logging_ops.print_v2(var_one, {"second": var_two})
+ self.evaluate(print_op)
+ expected = "3.14 {'second': [0 1 2 ... 7 8 9]}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintTwoTensors(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, tensor * 10)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9] [0 10 20 ... 70 80 90]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintPlaceholderGeneration(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
+ self.evaluate(print_op)
+ expected = "{}6 {'{}': [0 10 20 ... 70 80 90]}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintNoTensors(self):
+ with self.test_session():
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
+ self.evaluate(print_op)
+ expected = "23 [23, 5] {'6': 12}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintFloatScalar(self):
+ with self.test_session():
+ tensor = ops.convert_to_tensor(434.43)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+ expected = "434.43"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintStringScalar(self):
+ with self.test_session():
+ tensor = ops.convert_to_tensor("scalar")
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+ expected = "scalar"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintComplexTensorStruct(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ small_tensor = constant_op.constant([0.3, 12.4, -16.1])
+ big_tensor = math_ops.mul(tensor, 10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ "first:", tensor, "middle:",
+ {"small": small_tensor, "Big": big_tensor}, 10,
+ [tensor * 2, tensor])
+ self.evaluate(print_op)
+ # Note that the keys in the dict will always be sorted,
+ # so 'Big' comes before 'small'
+ expected = ("first: [0 1 2 ... 7 8 9] "
+ "middle: {'Big': [0 10 20 ... 70 80 90], "
+ "'small': [0.3 12.4 -16.1]} "
+ "10 [[0 2 4 ... 14 16 18], [0 1 2 ... 7 8 9]]")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintSparseTensor(self):
+ with self.test_session():
+ ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
+ val = [0, 10, 13, 4, 14, 32, 33]
+ shape = [5, 6]
+
+ sparse = sparse_tensor.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int64),
+ constant_op.constant(shape, dtypes.int64))
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(sparse)
+ self.evaluate(print_op)
+ expected = ("'SparseTensor(indices=[[0 0]\n"
+ " [1 0]\n"
+ " [1 3]\n"
+ " ...\n"
+ " [1 4]\n"
+ " [3 2]\n"
+ " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])'")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintSparseTensorInDataStruct(self):
+ with self.test_session():
+ ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
+ val = [0, 10, 13, 4, 14, 32, 33]
+ shape = [5, 6]
+
+ sparse = sparse_tensor.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int64),
+ constant_op.constant(shape, dtypes.int64))
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2([sparse])
+ self.evaluate(print_op)
+ expected = ("['SparseTensor(indices=[[0 0]\n"
+ " [1 0]\n"
+ " [1 3]\n"
+ " ...\n"
+ " [1 4]\n"
+ " [3 2]\n"
+ " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])']")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorStdout(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stdout) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=sys.stdout)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogInfo(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.info)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogWarning(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.warning)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogError(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.error)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testInvalidOutputStreamRaisesError(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.assertRaises(ValueError):
+ print_op = logging_ops.print_v2(
+ tensor, output_stream="unknown")
+ self.evaluate(print_op)
+
+ def testPrintOpName(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ print_op = logging_ops.print_v2(tensor, name="print_name")
+ self.assertEqual(print_op.name, "print_name")
+
+ def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ formatted_string = string_ops.string_format("{}", tensor)
+ print_op = logging_ops.print_v2(formatted_string)
+ self.evaluate(print_op)
+ graph_ops = ops.get_default_graph().get_operations()
+ format_ops = [op for op in graph_ops if op.type == "StringFormat"]
+ # Should be only 1 format_op for graph mode.
+ self.assertEqual(len(format_ops), 1)
+
+ def testPrintOneTensorEagerOnOpCreate(self):
+ with self.test_session():
+ with context.eager_mode():
+ tensor = math_ops.range(10)
+ expected = "[0 1 2 ... 7 8 9]"
+ with self.captureWritesToStream(sys.stderr) as printed:
+ logging_ops.print_v2(tensor)
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintInDefunWithoutExplicitEvalOfPrint(self):
+ @function.defun
+ def f():
+ tensor = math_ops.range(10)
+ logging_ops.print_v2(tensor)
+ return tensor
+
+ expected = "[0 1 2 ... 7 8 9]"
+ with self.captureWritesToStream(sys.stderr) as printed_one:
+ x = f()
+ self.evaluate(x)
+ self.assertTrue((expected + "\n") in printed_one.contents())
+
+ # We execute the function again to make sure it doesn't only print on the
+ # first call.
+ with self.captureWritesToStream(sys.stderr) as printed_two:
+ y = f()
+ self.evaluate(y)
+ self.assertTrue((expected + "\n") in printed_two.contents())
+
+
class PrintGradientTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
@@ -65,6 +370,11 @@
inp_printed = logging_ops.Print(inp, [inp])
self.assertEqual(inp.get_shape(), inp_printed.get_shape())
+ def testPrintString(self):
+ inp = constant_op.constant(2.0, shape=[100, 32])
+ inp_printed = logging_ops.Print(inp, ["hello"])
+ self.assertEqual(inp.get_shape(), inp_printed.get_shape())
+
def testPrintGradient(self):
with self.cached_session():
inp = constant_op.constant(2.0, shape=[100, 32], name="in")
diff --git a/tensorflow/python/kernel_tests/string_format_op_test.py b/tensorflow/python/kernel_tests/string_format_op_test.py
new file mode 100644
index 0000000..afa71db
--- /dev/null
+++ b/tensorflow/python/kernel_tests/string_format_op_test.py
@@ -0,0 +1,384 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.kernels.logging_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class StringFormatOpTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDim(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.test_session():
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", [tensor])
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneVariableScalar(self):
+ with self.test_session():
+ var = variables.Variable(3.34)
+ format_output = string_ops.string_format("{}", [var])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ out = self.evaluate(format_output)
+ expected = "3.34"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneVariableOneDim(self):
+ with self.test_session():
+ var = variables.Variable(math_ops.range(10))
+ format_output = string_ops.string_format("{}", [var])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatTwoVariablesWithAssignAdd(self):
+ with self.test_session():
+ var_one = variables.Variable(2.14)
+ plus_one = var_one.assign_add(1.0)
+ var_two = variables.Variable(math_ops.range(10))
+ format_output = string_ops.string_format("{}, {}", [var_one, var_two])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ self.evaluate(plus_one)
+ out = self.evaluate(format_output)
+ expected = "3.14, [0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimFloat(self):
+ with self.test_session():
+ tensor = constant_op.constant([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = "[0 0.1 0.2 ... 0.5 0.6 0.7]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimMatchesSummarize(self):
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimVarySummarize(self):
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=-1)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=1)
+ out = self.evaluate(format_output)
+ expected = "[0 ... 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=2)
+ out = self.evaluate(format_output)
+ expected = "[0 1 ... 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=10)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimAlmostSummarize(self):
+ with self.test_session():
+ tensor = math_ops.range(5)
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDimLessThanSummarize(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(4), [2, 2])
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1]\n"
+ " [2 3]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDim(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDimSummarizeTwo(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}", tensor, summarize=2)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 ... 8 9]\n"
+ " [10 11 ... 18 19]\n"
+ " ...\n"
+ " [80 81 ... 88 89]\n"
+ " [90 91 ... 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorThreeDim(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(1000), [10, 10, 10])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]\n"
+ "\n"
+ " [[100 101 102 ... 107 108 109]\n"
+ " [110 111 112 ... 117 118 119]\n"
+ " [120 121 122 ... 127 128 129]\n"
+ " ...\n [170 171 172 ... 177 178 179]\n"
+ " [180 181 182 ... 187 188 189]\n"
+ " [190 191 192 ... 197 198 199]]\n"
+ "\n"
+ " [[200 201 202 ... 207 208 209]\n"
+ " [210 211 212 ... 217 218 219]\n"
+ " [220 221 222 ... 227 228 229]\n"
+ " ...\n"
+ " [270 271 272 ... 277 278 279]\n"
+ " [280 281 282 ... 287 288 289]\n"
+ " [290 291 292 ... 297 298 299]]\n"
+ "\n"
+ " ...\n"
+ "\n"
+ " [[700 701 702 ... 707 708 709]\n"
+ " [710 711 712 ... 717 718 719]\n"
+ " [720 721 722 ... 727 728 729]\n"
+ " ...\n"
+ " [770 771 772 ... 777 778 779]\n"
+ " [780 781 782 ... 787 788 789]\n"
+ " [790 791 792 ... 797 798 799]]\n"
+ "\n"
+ " [[800 801 802 ... 807 808 809]\n"
+ " [810 811 812 ... 817 818 819]\n"
+ " [820 821 822 ... 827 828 829]\n"
+ " ...\n"
+ " [870 871 872 ... 877 878 879]\n"
+ " [880 881 882 ... 887 888 889]\n"
+ " [890 891 892 ... 897 898 899]]\n"
+ "\n"
+ " [[900 901 902 ... 907 908 909]\n"
+ " [910 911 912 ... 917 918 919]\n"
+ " [920 921 922 ... 927 928 929]\n"
+ " ...\n"
+ " [970 971 972 ... 977 978 979]\n"
+ " [980 981 982 ... 987 988 989]\n"
+ " [990 991 992 ... 997 998 999]]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplatePrefix(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplatePrefixAndSuffix(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}, suffix",
+ tensor)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], suffix")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplateSuffix(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}, suffix", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], suffix")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatNoTensor(self):
+ with self.test_session():
+ format_output = string_ops.string_format("No tensor.", ())
+ out = self.evaluate(format_output)
+ expected = "No tensor."
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatMultiTensor(self):
+ with self.test_session():
+ tensor_one = array_ops.reshape(math_ops.range(100), [10, 10])
+ tensor_two = tensor_one * 10
+ format_output = string_ops.string_format("One: {},\nTwo: {}",
+ (tensor_one, tensor_two))
+ out = self.evaluate(format_output)
+ expected = ("One: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]],\n"
+ "Two: [[0 10 20 ... 70 80 90]\n"
+ " [100 110 120 ... 170 180 190]\n"
+ " [200 210 220 ... 270 280 290]\n"
+ " ...\n"
+ " [700 710 720 ... 770 780 790]\n"
+ " [800 810 820 ... 870 880 890]\n"
+ " [900 910 920 ... 970 980 990]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatSummarizeOne(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor,
+ summarize=1)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 ... 9]\n"
+ " ...\n"
+ " [90 ... 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatSummarizeTwo(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor,
+ summarize=2)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 ... 8 9]\n"
+ " [10 11 ... 18 19]\n"
+ " ...\n"
+ " [80 81 ... 88 89]\n"
+ " [90 91 ... 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatPlaceholder(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: %t%", tensor,
+ placeholder="%t%")
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testTensorCountMustMatchPlaceholderCount(self):
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"2 placeholder\(s\) in template does not match 1 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{} {}", tensor)
+ self.evaluate(format_output)
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"2 placeholder\(s\) in template does not match 1 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{} {}", [tensor])
+ self.evaluate(format_output)
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"1 placeholder\(s\) in template does not match 2 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", (tensor, tensor))
+ self.evaluate(format_output)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
new file mode 100644
index 0000000..0c3b724
--- /dev/null
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -0,0 +1,276 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for while_v2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import while_v2
+from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1
+from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
+from tensorflow.python.platform import test
+
+
+class WhileV2Test(test.TestCase, parameterized.TestCase):
+
+ def testSingleLoopVar(self):
+ x = constant_op.constant(2.)
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x])
+ grad = gradients_impl.gradients(ret, [x])
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(ret), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+
+ def testMultipleLoopVarsBasic(self):
+ x = constant_op.constant(5.)
+ y = constant_op.constant(3.)
+
+ # x = 5.
+ # y = 3.
+ # while x < 45.:
+ # x = x * y
+ ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, w), [x, y])
+ # ret = [x*y^2, y]
+
+ # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
+ grad = gradients_impl.gradients(ret, [x]) # [2*x*y]
+ with self.test_session() as sess:
+ self.assertSequenceEqual(sess.run(ret), [45., 3.])
+ self.assertSequenceEqual(sess.run(grad), [9.])
+
+ def testMultipleLoopVars(self):
+ x = constant_op.constant(5.)
+ y = constant_op.constant(3.)
+
+ # x = 5.
+ # y = 3.
+ # while x < 45.:
+ # x = x * y
+ # y = x + y
+ ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, v + w),
+ [x, y])
+ # ret = [y*x**2 + x*y**2, x*y + x + y]
+
+ gradx_0 = gradients_impl.gradients(ret[0], [x]) # [2*x*y + y**2]
+ gradx_1 = gradients_impl.gradients(ret[1], [x]) # [y + 1]
+ gradx_2 = gradients_impl.gradients(ret, [x]) # [2*x*y + y**2 + 2*y + 1]
+ grady_0 = gradients_impl.gradients(ret[0], [y]) # [2*x*y + x**2]
+ grady_1 = gradients_impl.gradients(ret[1], [y]) # [x + 1]
+ grady_2 = gradients_impl.gradients(ret, [y]) # [2*x*y + x**2 + x + 1]
+ with self.test_session() as sess:
+ self.assertSequenceEqual(sess.run(ret), [120., 23.])
+ self.assertSequenceEqual(sess.run(gradx_0), [39.])
+ self.assertSequenceEqual(sess.run(gradx_1), [4.])
+ self.assertSequenceEqual(sess.run(gradx_2), [43.])
+ self.assertSequenceEqual(sess.run(grady_0), [55.])
+ self.assertSequenceEqual(sess.run(grady_1), [6.])
+ self.assertSequenceEqual(sess.run(grady_2), [61.])
+
+ def testMultipleWhileLoops(self):
+ x = constant_op.constant(2.)
+ ret1 = while_loop_v2(lambda v: v < 4., lambda v: v * v, [x]) # x**2
+ ret2 = while_loop_v2(lambda v: v < 16., lambda v: v * v, ret1) # x**4
+ grad = gradients_impl.gradients(ret2, [x]) # 4x**3
+ grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
+ with self.test_session() as sess:
+ self.assertSequenceEqual(sess.run(grad), [32.])
+ self.assertSequenceEqual(sess.run(grad_grad), [48.])
+
+ def testDoubleDerivative(self):
+ x = constant_op.constant(2.)
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v**2, [x]) # x**4
+ grad = gradients_impl.gradients(ret, [x]) # 4x**3
+ grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(ret), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+ self.assertSequenceEqual(sess.run(grad_grad), [48.])
+
+ def testPruning(self):
+ x = constant_op.constant(1)
+
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=x.dtype, element_shape=x.shape)
+
+ def Cond(x, tl):
+ del tl # Unused for Cond.
+ return x < 5
+
+ def Body(x, tl):
+ return x + 1, list_ops.tensor_list_push_back(tl, x)
+
+ outputs = while_loop_v1(Cond, Body, [x, tensor_list])
+
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ train_op.append(outputs[0])
+
+ def GetOptimizedGraph():
+ mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
+ memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
+ return tf_optimizer.OptimizeGraph(rewriter_config, mg)
+
+ g = GetOptimizedGraph()
+ self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1)
+
+ stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
+ train_op.append(stack)
+ g = GetOptimizedGraph()
+ self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2)
+
+ def testCaptureExternalTensorInCond(self):
+ x = constant_op.constant(2.)
+ y = constant_op.constant(1.)
+ ret = while_loop_v2(lambda v: v + y < 9., lambda v: v * 3., [x])
+ grad = gradients_impl.gradients(ret, [x])
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(ret), 18.)
+ self.assertSequenceEqual(sess.run(grad), [9.])
+
+ def testCaptureExternalTensorInBody(self):
+ x = constant_op.constant(2.)
+ y = constant_op.constant(3.)
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v * y, [x])
+ grad = gradients_impl.gradients(ret, [x])
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(ret), 18.)
+ self.assertSequenceEqual(sess.run(grad), [9.])
+
+ def testLoopWithTensorListPushBack(self):
+ x = constant_op.constant(2.)
+
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32, element_shape=ScalarShape())
+
+ def Cond(x, tl):
+ del tl # Unused for Cond.
+ return x < 5.
+
+ def Body(x, tl):
+ tl = list_ops.tensor_list_push_back(tl, x)
+ tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
+ return x**2., tl
+
+ ret = while_loop_v2(Cond, Body, [x, tensor_list])
+ grad = gradients_impl.gradients(ret[0], x)
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(ret[0]), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+
+ def testDuplicateAccumulator(self):
+ x = constant_op.constant(2.)
+
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32, element_shape=ScalarShape())
+
+ def Cond(x, tl):
+ del tl # Unused for Cond.
+ return x < 5.
+
+ def Body(x, tl):
+ # There is an accumulator in the loop already so we should not add
+ # another.
+ tl = list_ops.tensor_list_push_back(tl, x)
+ return x**2., tl
+
+ ret = while_loop_v2(Cond, Body, [x, tensor_list])
+
+ for op in ops.get_default_graph().get_operations():
+ if op.type == "While":
+ while_op = op
+
+ body_graph = while_v2._get_body_graph(while_op)
+ # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators]
+ x_input_t = body_graph.inputs[1]
+ accumulator_count = len(
+ [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
+ self.assertEqual(accumulator_count, 1)
+
+ grad = gradients_impl.gradients(ret[0], x)
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(ret[0]), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+
+ @parameterized.named_parameters(
+ ("UnknownShape", None),
+ ("PartiallyDefinedShape", [None, 2]),
+ ("FullyDefinedShape", [1, 2]),
+ )
+ def testTensorListOutputElementShape(self, shape):
+
+ def MatchShape(actual_tensor_shape):
+ # Compare the shapes, treating None dimensions as equal. We do not
+ # directly check actual_tensor_shape and tf.TensorShape(shape) for
+ # equality because tf.Dimension.__eq__ returns None if either dimension is
+ # None.
+ if shape is None:
+ self.assertIsNone(actual_tensor_shape.dims)
+ else:
+ self.assertListEqual(actual_tensor_shape.as_list(), shape)
+
+ def GetAccumulatorForInputAtIndex(while_op, idx):
+ body_graph = while_v2._get_body_graph(while_op)
+ y_input_t = body_graph.inputs[idx]
+ push_back_node = [c for c in y_input_t.consumers()
+ if c.type == "TensorListPushBack"][0]
+ output_idx = body_graph.outputs.index(push_back_node.outputs[0])
+ return while_op.outputs[output_idx]
+
+ x = constant_op.constant(2.)
+ y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
+
+ # Forward pass.
+ ret = while_loop_v2(lambda v, u: v < 8., lambda v, u: (v * v, u), [x, y])
+ while_op = ret[0].op
+ # Get the TensorList output of While op containing the accumulated values
+ # of y.
+ # while_op.inputs: [counter_arg, x_arg, y_arg, *accumulators]
+ output = GetAccumulatorForInputAtIndex(while_op, 2)
+ _, val = list_ops.tensor_list_pop_back(output,
+ element_dtype=dtypes.float32)
+ MatchShape(val.shape)
+
+ # Gradient pass.
+ grad = gradients_impl.gradients(ret[1], y)
+ grad_while_op = grad[0].op
+ # Get the TensorList output of gradient While op containing the accumulated
+ # values of grad_y.
+ # grad_while_op.inputs:
+ # [counter_arg, total_iters_arg, grad_x_arg, grad_y_arg, *other_args]
+ grad_output = GetAccumulatorForInputAtIndex(grad_while_op, 4)
+ _, val = list_ops.tensor_list_pop_back(grad_output,
+ element_dtype=dtypes.float32)
+ MatchShape(val.shape)
+
+
+def ScalarShape():
+ return ops.convert_to_tensor([], dtype=dtypes.int32)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index c8b8833..a7f57e9 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2787,4 +2787,65 @@
name=name)
+@tf_export("searchsorted")
+def searchsorted(sorted_sequence,
+ values,
+ side="left",
+ out_type=dtypes.int32,
+ name=None):
+ """Searches input tensor for values on the innermost dimension.
+
+ A 2-D example:
+
+ ```
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = searchsorted(sorted_sequence, values, side="left")
+
+ result == [[1, 2, 2],
+ [0, 1, 5]]
+
+ result = searchsorted(sorted_sequence, values, side="right")
+
+ result == [[1, 2, 4],
+ [0, 2, 5]]
+ ```
+
+ Args:
+ sorted_sequence: N-D `Tensor` containing a sorted sequence.
+ values: N-D `Tensor` containing the search values.
+ side: 'left' or 'right'; 'left' corresponds to lower_bound and 'right' to
+ upper_bound.
+ out_type: The output type (`int32` or `int64`). Default is `tf.int32`.
+ name: Optional name for the operation.
+
+ Returns:
+ An N-D `Tensor` the size of values containing the result of applying either
+ lower_bound or upper_bound (depending on side) to each value. The result
+ is not a global index to the entire `Tensor`, but the index in the last
+ dimension.
+
+ Raises:
+ ValueError: If the last dimension of `sorted_sequence >= 2^31-1` elements.
+ If the total size of values exceeds `2^31 - 1` elements.
+ If the first `N-1` dimensions of the two tensors don't match.
+ """
+ sequence_size = shape_internal(sorted_sequence)[-1]
+ values_size = shape_internal(values)[-1]
+ sorted_sequence_2d = reshape(sorted_sequence, [-1, sequence_size])
+ values_2d = reshape(values, [-1, values_size])
+ if side == "right":
+ output = gen_array_ops.upper_bound(sorted_sequence_2d, values_2d, out_type,
+ name)
+ elif side == "left":
+ output = gen_array_ops.lower_bound(sorted_sequence_2d, values_2d, out_type,
+ name)
+ else:
+ raise ValueError("side must be either 'right' or 'left'. Saw: %s." % side)
+ return reshape(output, shape_internal(values))
+
+
quantize.__doc__ = gen_array_ops.quantize_v2.__doc__
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 0e20fad..87f8bd8 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -610,9 +610,10 @@
"less-specific shape." %
(input_t.name, input_t.shape, n_shape))
else:
- if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
- raise TypeError("Type %s not supported" % type(var))
- if isinstance(var, ops.IndexedSlices):
+ if not isinstance(merge_var,
+ (ops.IndexedSlices, sparse_tensor.SparseTensor)):
+ raise TypeError("Type %s not supported" % type(merge_var))
+ if isinstance(merge_var, ops.IndexedSlices):
m_values_shape = merge_var.values.get_shape()
m_indices_shape = merge_var.indices.get_shape()
m_shape_shape = tensor_shape.TensorShape(None)
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index 99d30b0..2ba1ea6 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -98,10 +98,13 @@
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Create a batch of three Beta distributions.
alpha = [1, 2, 3]
beta = [1, 2, 3]
- dist = tf.distributions.Beta(alpha, beta)
+ dist = tfd.Beta(alpha, beta)
dist.sample([4, 5]) # Shape [4, 5, 3]
@@ -117,7 +120,7 @@
# Create batch_shape=[2, 3] via parameter broadcast:
alpha = [[1.], [2]] # Shape [2, 1]
beta = [3., 4, 5] # Shape [3]
- dist = tf.distributions.Beta(alpha, beta)
+ dist = tfd.Beta(alpha, beta)
# alpha broadcast as: [[1., 1, 1,],
# [2, 2, 2]]
@@ -138,7 +141,7 @@
```python
alpha = tf.constant(1.0)
beta = tf.constant(2.0)
- dist = tf.distributions.Beta(alpha, beta)
+ dist = tfd.Beta(alpha, beta)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index 9104a1d..415249a 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -104,10 +104,13 @@
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Create a single trivariate Dirichlet, with the 3rd class being three times
# more frequent than the first. I.e., batch_shape=[], event_shape=[3].
alpha = [1., 2, 3]
- dist = tf.distributions.Dirichlet(alpha)
+ dist = tfd.Dirichlet(alpha)
dist.sample([4, 5]) # shape: [4, 5, 3]
@@ -129,7 +132,7 @@
# Create batch_shape=[2], event_shape=[3]:
alpha = [[1., 2, 3],
[4, 5, 6]] # shape: [2, 3]
- dist = tf.distributions.Dirichlet(alpha)
+ dist = tfd.Dirichlet(alpha)
dist.sample([4, 5]) # shape: [4, 5, 2, 3]
@@ -144,7 +147,7 @@
```python
alpha = tf.constant([1.0, 2.0, 3.0])
- dist = tf.distributions.Dirichlet(alpha)
+ dist = tfd.Dirichlet(alpha)
samples = dist.sample(5) # Shape [5, 3]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 578e7b7..76d9806 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -601,7 +601,8 @@
return type(self)(**parameters)
def _batch_shape_tensor(self):
- raise NotImplementedError("batch_shape_tensor is not implemented")
+ raise NotImplementedError(
+ "batch_shape_tensor is not implemented: {}".format(type(self).__name__))
def batch_shape_tensor(self, name="batch_shape_tensor"):
"""Shape of a single sample from a single event index as a 1-D `Tensor`.
@@ -640,7 +641,8 @@
return tensor_shape.as_shape(self._batch_shape())
def _event_shape_tensor(self):
- raise NotImplementedError("event_shape_tensor is not implemented")
+ raise NotImplementedError(
+ "event_shape_tensor is not implemented: {}".format(type(self).__name__))
def event_shape_tensor(self, name="event_shape_tensor"):
"""Shape of a single sample from a single batch as a 1-D int32 `Tensor`.
@@ -701,7 +703,8 @@
name="is_scalar_batch")
def _sample_n(self, n, seed=None):
- raise NotImplementedError("sample_n is not implemented")
+ raise NotImplementedError("sample_n is not implemented: {}".format(
+ type(self).__name__))
def _call_sample_n(self, sample_shape, seed, name, **kwargs):
with self._name_scope(name, values=[sample_shape]):
@@ -733,15 +736,19 @@
return self._call_sample_n(sample_shape, seed, name)
def _log_prob(self, value):
- raise NotImplementedError("log_prob is not implemented")
+ raise NotImplementedError("log_prob is not implemented: {}".format(
+ type(self).__name__))
def _call_log_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_prob(value, **kwargs)
- except NotImplementedError:
- return math_ops.log(self._prob(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.log(self._prob(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def log_prob(self, value, name="log_prob"):
"""Log probability density/mass function.
@@ -757,15 +764,19 @@
return self._call_log_prob(value, name)
def _prob(self, value):
- raise NotImplementedError("prob is not implemented")
+ raise NotImplementedError("prob is not implemented: {}".format(
+ type(self).__name__))
def _call_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._prob(value, **kwargs)
- except NotImplementedError:
- return math_ops.exp(self._log_prob(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.exp(self._log_prob(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def prob(self, value, name="prob"):
"""Probability density/mass function.
@@ -781,15 +792,19 @@
return self._call_prob(value, name)
def _log_cdf(self, value):
- raise NotImplementedError("log_cdf is not implemented")
+ raise NotImplementedError("log_cdf is not implemented: {}".format(
+ type(self).__name__))
def _call_log_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_cdf(value, **kwargs)
- except NotImplementedError:
- return math_ops.log(self._cdf(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.log(self._cdf(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def log_cdf(self, value, name="log_cdf"):
"""Log cumulative distribution function.
@@ -815,15 +830,19 @@
return self._call_log_cdf(value, name)
def _cdf(self, value):
- raise NotImplementedError("cdf is not implemented")
+ raise NotImplementedError("cdf is not implemented: {}".format(
+ type(self).__name__))
def _call_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._cdf(value, **kwargs)
- except NotImplementedError:
- return math_ops.exp(self._log_cdf(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.exp(self._log_cdf(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def cdf(self, value, name="cdf"):
"""Cumulative distribution function.
@@ -845,15 +864,20 @@
return self._call_cdf(value, name)
def _log_survival_function(self, value):
- raise NotImplementedError("log_survival_function is not implemented")
+ raise NotImplementedError(
+ "log_survival_function is not implemented: {}".format(
+ type(self).__name__))
def _call_log_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_survival_function(value, **kwargs)
- except NotImplementedError:
- return math_ops.log1p(-self.cdf(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.log1p(-self.cdf(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def log_survival_function(self, value, name="log_survival_function"):
"""Log survival function.
@@ -880,15 +904,19 @@
return self._call_log_survival_function(value, name)
def _survival_function(self, value):
- raise NotImplementedError("survival_function is not implemented")
+ raise NotImplementedError("survival_function is not implemented: {}".format(
+ type(self).__name__))
def _call_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._survival_function(value, **kwargs)
- except NotImplementedError:
- return 1. - self.cdf(value, **kwargs)
+ except NotImplementedError as original_exception:
+ try:
+ return 1. - self.cdf(value, **kwargs)
+ except NotImplementedError:
+ raise original_exception
def survival_function(self, value, name="survival_function"):
"""Survival function.
@@ -912,7 +940,8 @@
return self._call_survival_function(value, name)
def _entropy(self):
- raise NotImplementedError("entropy is not implemented")
+ raise NotImplementedError("entropy is not implemented: {}".format(
+ type(self).__name__))
def entropy(self, name="entropy"):
"""Shannon entropy in nats."""
@@ -920,7 +949,8 @@
return self._entropy()
def _mean(self):
- raise NotImplementedError("mean is not implemented")
+ raise NotImplementedError("mean is not implemented: {}".format(
+ type(self).__name__))
def mean(self, name="mean"):
"""Mean."""
@@ -928,7 +958,8 @@
return self._mean()
def _quantile(self, value):
- raise NotImplementedError("quantile is not implemented")
+ raise NotImplementedError("quantile is not implemented: {}".format(
+ type(self).__name__))
def _call_quantile(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
@@ -955,7 +986,8 @@
return self._call_quantile(value, name)
def _variance(self):
- raise NotImplementedError("variance is not implemented")
+ raise NotImplementedError("variance is not implemented: {}".format(
+ type(self).__name__))
def variance(self, name="variance"):
"""Variance.
@@ -979,11 +1011,15 @@
with self._name_scope(name):
try:
return self._variance()
- except NotImplementedError:
- return math_ops.square(self._stddev())
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.square(self._stddev())
+ except NotImplementedError:
+ raise original_exception
def _stddev(self):
- raise NotImplementedError("stddev is not implemented")
+ raise NotImplementedError("stddev is not implemented: {}".format(
+ type(self).__name__))
def stddev(self, name="stddev"):
"""Standard deviation.
@@ -1008,11 +1044,15 @@
with self._name_scope(name):
try:
return self._stddev()
- except NotImplementedError:
- return math_ops.sqrt(self._variance())
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.sqrt(self._variance())
+ except NotImplementedError:
+ raise original_exception
def _covariance(self):
- raise NotImplementedError("covariance is not implemented")
+ raise NotImplementedError("covariance is not implemented: {}".format(
+ type(self).__name__))
def covariance(self, name="covariance"):
"""Covariance.
@@ -1054,7 +1094,8 @@
return self._covariance()
def _mode(self):
- raise NotImplementedError("mode is not implemented")
+ raise NotImplementedError("mode is not implemented: {}".format(
+ type(self).__name__))
def mode(self, name="mode"):
"""Mode."""
@@ -1080,7 +1121,7 @@
where `F` denotes the support of the random variable `X ~ P`.
Args:
- other: `tf.distributions.Distribution` instance.
+ other: `tfp.distributions.Distribution` instance.
name: Python `str` prepended to names of ops created by this function.
Returns:
@@ -1111,7 +1152,7 @@
denotes (Shanon) cross entropy, and `H[.]` denotes (Shanon) entropy.
Args:
- other: `tf.distributions.Distribution` instance.
+ other: `tfp.distributions.Distribution` instance.
name: Python `str` prepended to names of ops created by this function.
Returns:
@@ -1123,7 +1164,7 @@
return self._kl_divergence(other)
def __str__(self):
- return ("tf.distributions.{type_name}("
+ return ("tfp.distributions.{type_name}("
"\"{self_name}\""
"{maybe_batch_shape}"
"{maybe_event_shape}"
@@ -1139,7 +1180,7 @@
dtype=self.dtype.name))
def __repr__(self):
- return ("<tf.distributions.{type_name} "
+ return ("<tfp.distributions.{type_name} "
"'{self_name}'"
" batch_shape={batch_shape}"
" event_shape={event_shape}"
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index b631f02..3293cda 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -100,8 +100,11 @@
#### Examples
```python
- dist = tf.distributions.Gamma(concentration=3.0, rate=2.0)
- dist2 = tf.distributions.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
+ dist = tfd.Gamma(concentration=3.0, rate=2.0)
+ dist2 = tfd.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
```
Compute the gradients of samples w.r.t. the parameters:
@@ -109,7 +112,7 @@
```python
concentration = tf.constant(3.0)
rate = tf.constant(2.0)
- dist = tf.distributions.Gamma(concentration, rate)
+ dist = tfd.Gamma(concentration, rate)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/kullback_leibler.py b/tensorflow/python/ops/distributions/kullback_leibler.py
index e3c6f3e..fdeb97b 100644
--- a/tensorflow/python/ops/distributions/kullback_leibler.py
+++ b/tensorflow/python/ops/distributions/kullback_leibler.py
@@ -127,8 +127,8 @@
where `F` denotes the support of the random variable `X ~ P`.
Args:
- ref: `tf.distributions.Distribution` instance.
- other: `tf.distributions.Distribution` instance.
+ ref: `tfd.Distribution` instance.
+ other: `tfd.Distribution` instance.
allow_nan_stats: Python `bool`, default `True`. When `True`,
statistics (e.g., mean, mode, variance) use the value "`NaN`" to
indicate the result is undefined. When `False`, an exception is raised
diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py
index d0a987b..2feaf80 100644
--- a/tensorflow/python/ops/distributions/normal.py
+++ b/tensorflow/python/ops/distributions/normal.py
@@ -71,15 +71,18 @@
Examples of initialization of one or a batch of distributions.
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Define a single scalar Normal distribution.
- dist = tf.distributions.Normal(loc=0., scale=3.)
+ dist = tfd.Normal(loc=0., scale=3.)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1.)
# Define a batch of two scalar valued Normals.
# The first has mean 1 and standard deviation 11, the second 2 and 22.
- dist = tf.distributions.Normal(loc=[1, 2.], scale=[11, 22.])
+ dist = tfd.Normal(loc=[1, 2.], scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -94,7 +97,7 @@
```python
# Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations.
- dist = tf.distributions.Normal(loc=1., scale=[11, 22.])
+ dist = tfd.Normal(loc=1., scale=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py
index e0cf6f8..e8d214b 100644
--- a/tensorflow/python/ops/distributions/student_t.py
+++ b/tensorflow/python/ops/distributions/student_t.py
@@ -91,8 +91,11 @@
Examples of initialization of one or a batch of distributions.
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Define a single scalar Student t distribution.
- single_dist = tf.distributions.StudentT(df=3)
+ single_dist = tfd.StudentT(df=3)
# Evaluate the pdf at 1, returning a scalar Tensor.
single_dist.prob(1.)
@@ -100,9 +103,7 @@
# Define a batch of two scalar valued Student t's.
# The first has degrees of freedom 2, mean 1, and scale 11.
# The second 3, 2 and 22.
- multi_dist = tf.distributions.StudentT(df=[2, 3],
- loc=[1, 2.],
- scale=[11, 22.])
+ multi_dist = tfd.StudentT(df=[2, 3], loc=[1, 2.], scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -117,7 +118,7 @@
```python
# Define a batch of two Student's t distributions.
# Both have df 2 and mean 1, but different scales.
- dist = tf.distributions.StudentT(df=2, loc=1, scale=[11, 22.])
+ dist = tfd.StudentT(df=2, loc=1, scale=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
@@ -130,7 +131,7 @@
df = tf.constant(2.0)
loc = tf.constant(2.0)
scale = tf.constant(11.0)
- dist = tf.distributions.StudentT(df=df, loc=loc, scale=scale)
+ dist = tfd.StudentT(df=df, loc=loc, scale=scale)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
@@ -138,7 +139,6 @@
```
"""
- # pylint: enable=line-too-long
def __init__(self,
df,
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index de260f3..325418d 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -29,7 +29,6 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@@ -301,21 +300,21 @@
def _random_flip(image, flip_index, seed, scope_name):
"""Randomly (50% chance) flip an image along axis `flip_index`.
- Args:
- image: 4-D Tensor of shape `[batch, height, width, channels]` or
- 3-D Tensor of shape `[height, width, channels]`.
- flip_index: The dimension along which to flip the image.
- Vertical: 0, Horizontal: 1
- seed: A Python integer. Used to create a random seed. See
- `tf.set_random_seed`
- for behavior.
- scope_name: Name of the scope in which the ops are added.
- Returns:
- A tensor of the same type and shape as `image`.
+ Args:
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
+ flip_index: Dimension along which to flip image. Vertical: 0, Horizontal: 1
+ seed: A Python integer. Used to create a random seed. See
+ `tf.set_random_seed`
+ for behavior.
+ scope_name: Name of the scope in which the ops are added.
- Raises:
- ValueError: if the shape of `image` not supported.
+ Returns:
+ A tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` not supported.
"""
with ops.name_scope(None, scope_name, [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
@@ -334,15 +333,16 @@
result = result[0] # TODO(b/111124878) remove this logic (CondV2).
return fix_image_flip_shape(image, result)
elif shape.ndims == 4:
+ batch_size = array_ops.shape(image)[0]
uniform_random = random_ops.random_uniform(
- [array_ops.shape(image)[0]], 0, 1.0, seed=seed
+ [batch_size], 0, 1.0, seed=seed
)
- mirror_cond = math_ops.less(uniform_random, .5)
- return array_ops.where(
- mirror_cond,
- image,
- functional_ops.map_fn(lambda x: array_ops.reverse(x, [flip_index]), image, dtype=image.dtype)
+ flips = math_ops.round(
+ array_ops.reshape(uniform_random, [batch_size, 1, 1, 1])
)
+ flips = math_ops.cast(flips, image.dtype)
+ flipped_input = array_ops.reverse(image, [flip_index + 1])
+ return flips * flipped_input + (1 - flips) * image
else:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index df41933..4c53f33 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -19,13 +19,24 @@
from __future__ import division
from __future__ import print_function
+import pprint
+import random
+import sys
+
+import six
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_logging_ops
+from tensorflow.python.ops import string_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_logging_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -40,7 +51,32 @@
# For users with Python 3 or Python 2.7
# with `from __future__ import print_function`, we could also allow lowercase.
# See https://github.com/tensorflow/tensorflow/issues/18053
-@tf_export("Print")
+
+
+# pylint: disable=invalid-name
+@deprecated("2018-08-20", "Use tf.print instead of tf.Print. Note that "
+ "tf.print returns a no-output operator that directly "
+ "prints the output. Outside of defuns or eager mode, "
+ "this operator will not be executed unless it is "
+ "directly specified in session.run or used as a "
+ "control dependency for other operators. This is "
+ "only a concern in graph mode. Below is an example "
+ "of how to ensure tf.print executes in graph mode:\n"
+ """```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ print_op = tf.print(tensor)
+ with tf.control_dependencies([print_op]):
+ out = tf.add(tensor, tensor)
+ sess.run(out)
+ ```
+Additionally, to use tf.print in python 2.7, users must make sure to import
+the following:
+
+ `from __future__ import print_function`
+""")
+@tf_export(v1=["Print"])
def Print(input_, data, message=None, first_n=None, summarize=None,
name=None):
"""Prints a list of tensors.
@@ -66,6 +102,228 @@
A `Tensor`. Has the same type and contents as `input_`.
"""
return gen_logging_ops._print(input_, data, message, first_n, summarize, name)
+# pylint: enable=invalid-name
+
+
+def _generate_placeholder_string(x, default_placeholder="{}"):
+ """Generate and return a string that does not appear in `x`."""
+ placeholder = default_placeholder
+ rng = random.Random(5)
+ while placeholder in x:
+ placeholder = placeholder + str(rng.randint(0, 9))
+ return placeholder
+
+
+# Temporarily disable pylint g-doc-args error to allow giving more context
+# about what the kwargs are.
+# Because we are using arbitrary-length positional arguments, python 2
+# does not support explicitly specifying the keyword arguments in the
+# function definition.
+# pylint: disable=g-doc-args
+@tf_export("print")
+def print_v2(*inputs, **kwargs):
+ """Print the specified inputs.
+
+ Returns an operator that prints the specified inputs to a desired
+ output stream or logging level. The inputs may be dense or sparse Tensors,
+ primitive python objects, data structures that contain Tensors, and printable
+ python objects. Printed tensors will recursively show the first and last
+ `summarize` elements of each dimension.
+
+ With eager execution enabled and/or inside a `tf.contrib.eager.defun` this
+ operator will automatically execute, and users only need to call `tf.print`
+ without using the return value. When constructing graphs outside of a
+ `tf.contrib.eager.defun`, one must either include the returned op
+ in the input to `session.run`, or use the operator as a control dependency for
+ executed ops by specifying `with tf.control_dependencies([print_op])`.
+
+ @compatibility(python2)
+ In python 2.7, make sure to import the following:
+ `from __future__ import print_function`
+ @end_compatibility
+
+ Example:
+ Single-input usage:
+ ```python
+ tf.enable_eager_execution()
+ tensor = tf.range(10)
+ tf.print(tensor, output_stream=sys.stderr)
+ ```
+ (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+ Multi-input usage:
+ ```python
+ tf.enable_eager_execution()
+ tensor = tf.range(10)
+ tf.print("tensors:", tensor, {2: tensor * 2}, output_stream=sys.stdout)
+ ```
+ (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+ sys.stdout)
+
+ Usage in a defun:
+ ```python
+ tf.enable_eager_execution()
+
+ @tf.contrib.eager.defun
+ def f():
+ tensor = tf.range(10)
+ tf.print(tensor, output_stream=sys.stderr)
+ return tensor
+
+ range_tensor = f()
+ ```
+ (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+ Usage when constructing graphs:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ print_op = tf.print("tensors:", tensor, {2: tensor * 2},
+ output_stream=sys.stdout)
+ with tf.control_dependencies([print_op]):
+ tripled_tensor = tensor * 3
+ sess.run(tripled_tensor)
+ ```
+ (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+ sys.stdout)
+
+ Note: This op is only partially compatible with Jupyter notebooks and colabs.
+ Because it prints to the C++ standard out / standard error, this will go
+ in the notebook kernel's console output, not in the notebook cell output.
+
+ Args:
+ *inputs: Positional arguments that are the inputs to print. Inputs in the
+ printed output will be separated by spaces. Inputs may be python
+ primitives, tensors, data structures such as dicts and lists that
+ may contain tensors (with the data structures possibly nested in
+ arbitrary ways), and printable python objects.
+ output_stream: The output stream or logging level to print to. Defaults to
+ sys.stderr, but sys.stdout, tf.logging.info, tf.logging.warning, and
+ tf.logging.error are also supported.
+ summarize: The first and last `summarize` elements within each dimension are
+ recursively printed per Tensor. If None, then the first 3 and last 3
+ elements of each dimension are printed for each tensor. If set to -1, it
+ will print all elements of every tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ A print operator that prints the specified inputs in the specified output
+ stream or logging level.
+
+ Raises:
+ ValueError: If an unsupported output stream is specified.
+ """
+ # Because we are using arbitrary-length positional arguments, python 2
+ # does not support explicitly specifying the keyword arguments in the
+ # function definition. So, we manually get the keyword arguments w/ default
+ # values here.
+ output_stream = kwargs.pop("output_stream", sys.stderr)
+ name = kwargs.pop("name", None)
+ summarize = kwargs.pop("summarize", 3)
+ if kwargs:
+ raise ValueError("Unrecognized keyword arguments for tf.print: %s" % kwargs)
+ format_name = None
+ if name:
+ format_name = name + "_format"
+
+ # Match the C++ string constants representing the different output streams.
+ # Keep this updated!
+ output_stream_to_constant = {
+ sys.stdout: "stdout",
+ sys.stderr: "stderr",
+ tf_logging.INFO: "log(info)",
+ tf_logging.info: "log(info)",
+ tf_logging.WARN: "log(warning)",
+ tf_logging.warning: "log(warning)",
+ tf_logging.warn: "log(warning)",
+ tf_logging.ERROR: "log(error)",
+ tf_logging.error: "log(error)",
+ }
+
+ output_stream_string = output_stream_to_constant.get(output_stream)
+ if not output_stream_string:
+ raise ValueError(
+ "Unsupported output stream or logging level " +
+ str(output_stream) + ". Supported streams are sys.stdout, "
+ "sys.stderr, tf.logging.info, "
+ "tf.logging.warning, tf.logging.error")
+
+ # If we are only printing a single string scalar, there is no need to format
+ if (len(inputs) == 1 and tensor_util.is_tensor(inputs[0])
+ and (not isinstance(inputs[0], sparse_tensor.SparseTensor))
+ and inputs[0].shape and (inputs[0].dtype == dtypes.string)):
+ formatted_string = inputs[0]
+ # Otherwise, we construct an appropriate template for the tensors we are
+ # printing, and format the template using those tensors.
+ else:
+ # For each input to this print function, we extract any nested tensors,
+ # and construct an appropriate template to format representing the
+ # printed input.
+ templates = []
+ tensors = []
+ tensor_free_structure = nest.map_structure(
+ lambda x: "" if tensor_util.is_tensor(x) else x,
+ inputs)
+ tensor_free_template = " ".join(pprint.pformat(x)
+ for x in tensor_free_structure)
+ placeholder = _generate_placeholder_string(tensor_free_template)
+
+ for input_ in inputs:
+ placeholders = []
+ # Use the nest utilities to flatten & process any nested elements in this
+ # input. The placeholder for a tensor in the template should be the
+ # placeholder string, and the placeholder for a non-tensor can just be
+ # the printed value of the non-tensor itself.
+ for x in nest.flatten(input_):
+ # support sparse tensors
+ if isinstance(x, sparse_tensor.SparseTensor):
+ tensors.extend([x.indices, x.values, x.dense_shape])
+ placeholders.append(
+ "SparseTensor(indices={}, values={}, shape={})".format(
+ placeholder, placeholder, placeholder)
+ )
+ elif tensor_util.is_tensor(x):
+ tensors.append(x)
+ placeholders.append(placeholder)
+ else:
+ placeholders.append(x)
+
+ if isinstance(input_, six.string_types):
+ # If the current input to format/print is a normal string, that string
+ # can act as the template.
+ cur_template = input_
+ else:
+ # We pack the placeholders into a data structure that matches the
+ # input data structure format, then format that data structure
+ # into a string template.
+ #
+ # NOTE: We must use pprint.pformat here for building the template for
+ # unordered data structures such as `dict`, because `str` doesn't
+ # guarantee orderings, while pprint prints in sorted order. pprint
+ # will match the ordering of `nest.flatten`.
+ # This even works when nest.flatten reorders OrderedDicts, because
+ # pprint is printing *after* the OrderedDicts have been reordered.
+ cur_template = pprint.pformat(
+ nest.pack_sequence_as(input_, placeholders))
+ templates.append(cur_template)
+
+ # We join the templates for the various inputs into a single larger
+ # template. We also remove all quotes surrounding the placeholders, so that
+ # the formatted/printed output will not contain quotes around tensors.
+ # (example of where these quotes might appear: if we have added a
+ # placeholder string into a list, then pretty-formatted that list)
+ template = " ".join(templates)
+ template = template.replace("'" + placeholder + "'", placeholder)
+ formatted_string = string_ops.string_format(
+ inputs=tensors, template=template, placeholder=placeholder,
+ summarize=summarize,
+ name=format_name)
+
+ return gen_logging_ops.print_v2(formatted_string,
+ output_stream=output_stream_string,
+ name=name)
+# pylint: enable=g-doc-args
@ops.RegisterGradient("Print")
diff --git a/tensorflow/python/ops/losses/util_test.py b/tensorflow/python/ops/losses/util_test.py
index 7fa7a41..df2e60e 100644
--- a/tensorflow/python/ops/losses/util_test.py
+++ b/tensorflow/python/ops/losses/util_test.py
@@ -28,7 +28,7 @@
def testGetRegularizationLoss(self):
# Empty regularization collection should evaluate to 0.0.
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0.0, util.get_regularization_loss().eval())
# Loss should sum.
@@ -36,14 +36,14 @@
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))
ops.add_to_collection(
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(5.0, util.get_regularization_loss().eval())
# Check scope capture mechanism.
with ops.name_scope('scope1'):
ops.add_to_collection(
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(-1.0))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(-1.0, util.get_regularization_loss('scope1').eval())
diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD
index 015181a..07fc943 100644
--- a/tensorflow/python/ops/parallel_for/BUILD
+++ b/tensorflow/python/ops/parallel_for/BUILD
@@ -123,6 +123,8 @@
"//third_party/py/numpy",
"//tensorflow/python:layers",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python/ops/losses",
],
diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py
index 460de0a..1f026b3 100644
--- a/tensorflow/python/ops/parallel_for/gradients.py
+++ b/tensorflow/python/ops/parallel_for/gradients.py
@@ -42,6 +42,7 @@
[y_1, ..., y_n, x_1, ..., x_m].
"""
flat_inputs = nest.flatten(inputs)
+ output_tensor_shape = output.shape
output_shape = array_ops.shape(output)
output = array_ops.reshape(output, [-1])
@@ -65,6 +66,7 @@
new_shape = array_ops.concat(
[output_shape, array_ops.shape(out)[1:]], axis=0)
out = array_ops.reshape(out, new_shape)
+ out.set_shape(output_tensor_shape.concatenate(flat_inputs[i].shape))
pfor_outputs[i] = out
return nest.pack_sequence_as(inputs, pfor_outputs)
diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py
index 628c676..5467f55 100644
--- a/tensorflow/python/ops/parallel_for/gradients_test.py
+++ b/tensorflow/python/ops/parallel_for/gradients_test.py
@@ -32,6 +32,8 @@
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.layers import layers as tf_layers
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops as tf_control_flow_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@@ -355,6 +357,30 @@
self.run_and_assert_equal(answer, jacobian_pfor)
self.run_and_assert_equal(answer, jacobian_while)
+ def test_jacobian_scan_shape(self):
+ # Shape x: [3, 4]
+ x = random_ops.random_uniform([3, 4])
+ elems = random_ops.random_uniform([6])
+ # Shape y: [6, 3, 4]
+ y = functional_ops.scan(lambda a, e: a + e, elems, initializer=x)
+ jacobian = gradients.jacobian(y, x)
+
+ expected_shape = [6, 3, 4, 3, 4]
+ self.assertAllEqual(expected_shape, jacobian.shape.as_list())
+
+ def test_jacobian_while_loop_shape(self):
+ # Shape x: [3, 4]
+ x = random_ops.random_uniform([3, 4])
+ _, y = tf_control_flow_ops.while_loop(lambda i, a: i > 5.,
+ lambda i, a: (i + 1, a + i),
+ (constant_op.constant(0.), x))
+ # Shape y: [2, 3]
+ y = y[:2, :3]
+ jacobian = gradients.jacobian(y, x)
+
+ expected_shape = [2, 3, 3, 4]
+ self.assertAllEqual(expected_shape, jacobian.shape.as_list())
+
def test_jacobian_unknown_shape(self):
with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, shape=[None, None])
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 55c2eb5..4a126e9 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -48,14 +48,14 @@
assert ops._USE_C_SHAPES # pylint: disable=protected-access
assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
- handle_data = pywrap_tensorflow.GetResourceHandleShapeAndType(
+ handle_data = pywrap_tensorflow.GetHandleShapeAndType(
graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access
return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
compat.as_bytes(handle_data))
-def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
+def eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
"""Creates a variable handle with information to do shape inference."""
container = ops.get_default_graph()._container # pylint: disable=protected-access
if container is None:
@@ -397,61 +397,33 @@
# When in eager mode use a uid for the shared_name, to prevent
# accidental sharing.
shared_name = "%s_%d" % (handle_name, ops.uid())
- if init_from_fn:
- # Use attr_scope and device(None) to simulate the behavior of
- # colocate_with when the variable we want to colocate with doesn't
- # yet exist.
- if self._in_graph_mode:
- attr = attr_value_pb2.AttrValue(
- list=attr_value_pb2.AttrValue.ListValue(
- s=[compat.as_bytes("loc:@%s" % handle_name)]))
- with ops.get_default_graph()._attr_scope({"_class": attr}):
- with ops.name_scope("Initializer"), ops.device(None):
- initial_value = ops.convert_to_tensor(
- initial_value(), name="initial_value", dtype=dtype)
- self._handle = _eager_safe_variable_handle(
- shape=initial_value.get_shape(),
- dtype=initial_value.dtype.base_dtype,
- shared_name=shared_name,
- name=name,
- graph_mode=self._in_graph_mode)
- self._shape = initial_value.get_shape()
- else:
- initial_value = initial_value()
- with ops.name_scope("Initializer"):
- initial_value = ops.convert_to_tensor(
- initial_value, name="initial_value", dtype=dtype)
- self._handle = _eager_safe_variable_handle(
- shape=initial_value.get_shape(),
- dtype=initial_value.dtype.base_dtype,
- shared_name=shared_name,
- name=name,
- graph_mode=False)
- self._shape = initial_value.get_shape()
- # pylint: enable=protected-access
-
- # Or get the initial value from a Tensor or Python object.
- else:
- with ops.name_scope("Initializer"):
+ # Use attr_scope and device(None) to simulate the behavior of
+ # colocate_with when the variable we want to colocate with doesn't
+ # yet exist.
+ attr = attr_value_pb2.AttrValue(
+ list=attr_value_pb2.AttrValue.ListValue(
+ s=[compat.as_bytes("loc:@%s" % handle_name)]))
+ with ops.get_default_graph()._attr_scope({"_class": attr}):
+ with ops.name_scope("Initializer"), ops.device(None):
initial_value = ops.convert_to_tensor(
- initial_value, name="initial_value", dtype=dtype)
- # pylint: disable=protected-access
- if (self._in_graph_mode and initial_value is not None and
- initial_value.op._get_control_flow_context() is not None):
- raise ValueError(
- "Initializer for variable %s is from inside a control-flow "
- "construct, such as a loop or conditional. When creating a "
- "variable inside a loop or conditional, use a lambda as the "
- "initializer." % name)
- # pylint: enable=protected-access
- self._handle = _eager_safe_variable_handle(
+ initial_value() if init_from_fn else initial_value,
+ name="initial_value", dtype=dtype)
+ self._handle = eager_safe_variable_handle(
shape=initial_value.get_shape(),
dtype=initial_value.dtype.base_dtype,
shared_name=shared_name,
name=name,
graph_mode=self._in_graph_mode)
- self._shape = initial_value.get_shape()
-
+ self._shape = initial_value.shape
+ # pylint: disable=protected-access
+ if (self._in_graph_mode and initial_value is not None and
+ initial_value.op._get_control_flow_context() is not None):
+ raise ValueError(
+ "Initializer for variable %s is from inside a control-flow "
+ "construct, such as a loop or conditional. When creating a "
+ "variable inside a loop or conditional, use a lambda as the "
+ "initializer." % name)
+ # pylint: enable=protected-access
self._unique_id = shared_name
self._initial_value = initial_value if self._in_graph_mode else None
self._handle_name = handle_name + ":0"
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index b2c6937..5d94946 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -29,14 +29,15 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.util import compat as util_compat
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_string_ops import *
+from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -103,6 +104,87 @@
rewrite=rewrite, replace_global=replace_global)
+@tf_export("strings.format")
+def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
+ r"""Formats a string template using a list of tensors.
+
+ Formats a string template using a list of tensors, abbreviating tensors by
+ only printing the first and last `summarize` elements of each dimension
+ (recursively). If formatting only one tensor into a template, the tensor does
+ not have to be wrapped in a list.
+
+ Example:
+ Formatting a single-tensor template:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ formatted = tf.strings.format("tensor: {}, suffix", tensor)
+ out = sess.run(formatted)
+ expected = "tensor: [0 1 2 ... 7 8 9], suffix"
+
+ assert(out.decode() == expected)
+ ```
+
+ Formatting a multi-tensor template:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor_one = tf.reshape(tf.range(100), [10, 10])
+ tensor_two = tf.range(10)
+ formatted = tf.strings.format("first: {}, second: {}, suffix",
+ (tensor_one, tensor_two))
+
+ out = sess.run(formatted)
+ expected = ("first: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix")
+
+ assert(out.decode() == expected)
+ ```
+
+ Args:
+ template: A string template to format tensor values into.
+ inputs: A list of `Tensor` objects, or a single Tensor.
+ The list of tensors to format into the template string. If a solitary
+ tensor is passed in, the input tensor will automatically be wrapped as a
+ list.
+ placeholder: An optional `string`. Defaults to `{}`.
+ At each placeholder occurring in the template, a subsequent tensor
+ will be inserted.
+ summarize: An optional `int`. Defaults to `3`.
+ When formatting the tensors, show the first and last `summarize`
+ entries of each tensor dimension (recursively). If set to -1, all
+ elements of the tensor will be shown.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`.
+
+ Raises:
+ ValueError: if the number of placeholders does not match the number of
+ inputs.
+ """
+ # If there is only one tensor to format, we will automatically wrap it in a
+ # list to simplify the user experience
+ if tensor_util.is_tensor(inputs):
+ inputs = [inputs]
+ if template.count(placeholder) != len(inputs):
+ raise ValueError("%s placeholder(s) in template does not match %s tensor(s)"
+ " provided as input" % (template.count(placeholder),
+ len(inputs)))
+
+ return gen_string_ops.string_format(inputs,
+ template=template,
+ placeholder=placeholder,
+ summarize=summarize,
+ name=name)
+
+
@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
"""Split elements of `source` based on `delimiter` into a `SparseTensor`.
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
new file mode 100644
index 0000000..875be31
--- /dev/null
+++ b/tensorflow/python/ops/while_v2.py
@@ -0,0 +1,580 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""while_v2 and gradient.
+
+This is a version of while_loop that emits a single While op, as well as the
+gradient function for While ops produced by while_loop. This will eventually
+replace the current tf.while_loop implementation once it reaches feature and
+performance parity.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.eager import function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import function_def_to_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl as cond_v2
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import gen_functional_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
+from tensorflow.python.util import nest
+
+# pylint: disable=protected-access
+
+# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
+# control dependencies on external nodes with at least 1 output.
+# Another idea is to create const nodes outside the loop and add control edges
+# to them and then pass those in as data inputs. This should probably be
+# handled in the CapturingGraph itself.
+
+
+def while_loop(cond, body, loop_vars, name=None):
+ """Like tf.while_loop, except emits a single While op."""
+ if not name:
+ name = "while"
+
+ with ops.name_scope(name) as scope:
+ with ops.name_scope(None):
+ cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
+ body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))
+
+ flattened_loop_vars = nest.flatten(loop_vars)
+ num_outputs = len(flattened_loop_vars)
+
+ # Add loop counter needed for computing gradients.
+ flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
+ ] + flattened_loop_vars
+
+ # Build a `cond` wrapper that can handle the extra counter loop_var.
+ def wrapped_cond(unused_loop_counter, *loop_vars):
+ return cond(*loop_vars)
+
+ cond_graph = function.func_graph_from_py_func(cond_name, wrapped_cond,
+ flattened_loop_vars, {})
+
+ # Add external_captures of cond to the list of loop vars.
+ # Note that external tensors will be treated as loop invariants, i.e.,
+ # the value of that tensor in each iteration is the same as it was at the
+ # beginning of the loop execution.
+ flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
+
+ def wrapped_body(loop_counter, *args):
+ """Loop body augmented with counter update.
+
+ Args:
+ loop_counter: Loop counter which needs to be incremented in the body.
+ *args: List of args
+ args[:num_outputs] - Args for the original loop body.
+ args[num_outputs:] - External captures of cond. These get passed
+ through as is.
+
+ Returns:
+ A list of tensors the same length as args.
+ """
+ outputs = body(*args[:num_outputs])
+ if not isinstance(outputs, collections.Sequence):
+ outputs = [outputs]
+
+ # Return the external_captures of cond_graph as is, i.e., treat them as
+ # loop invariants.
+ # TODO(srbs): Update lowering code to create _Enter nodes with
+ # is_constant=True for inputs that are directly passed to outputs.
+ return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])
+
+ body_graph = function.func_graph_from_py_func(body_name, wrapped_body,
+ flattened_loop_vars, {})
+ # Add external captures of body to the list of loop vars.
+ # Note that external tensors will be treated as loop invariants, i.e.,
+ # the value of that tensor in each iteration is the same as it was at the
+ # beginning of the loop execution.
+ flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
+ # TODO(srbs): Update lowering code to create _Enter nodes with
+ # is_constant=True for inputs that are directly passed to outputs.
+ body_graph.outputs.extend(body_graph.internal_captures)
+
+ # Capture `external_captures` of `body_graph` in `cond_graph` so that it
+ # expects to receive those as arguments.
+ # TODO(srbs): Dedup tensors that are captured in both the cond and body.
+ # This logic already exists in cond_v2.
+ with cond_graph.as_default():
+ for external_capture in body_graph.external_captures:
+ cond_graph.capture(external_capture)
+
+ # Export all tensors in the loop body that may be needed for gradient
+ # computation. We do this by accumulating the intermediate values in
+ # TensorLists.
+ intermediate_tensors = _get_intermediates(body_graph)
+
+ for intermediate_tensor in intermediate_tensors:
+ # TODO(srbs): Cache and re-use empty tensor lists.
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=intermediate_tensor.dtype,
+ element_shape=_get_tensor_convertible_shape(
+ intermediate_tensor.shape))
+ flattened_loop_vars.append(tensor_list)
+ with cond_graph.as_default():
+ # Add a placeholder to cond_graph's inputs corresponding to the
+ # tensor_list.
+ cond_graph.capture(tensor_list)
+ with body_graph.as_default():
+ # Push the intermediate tensor to the tensor list. This captures the
+ # `tensor_list` as well.
+ appended_tensor_list = list_ops.tensor_list_push_back(
+ tensor_list,
+ intermediate_tensor)
+ # Add this modified tensor list to the list of outputs.
+ body_graph.outputs.append(appended_tensor_list)
+
+ outputs = gen_functional_ops._while(
+ flattened_loop_vars,
+ cond_v2._create_new_tf_function(cond_graph),
+ cond_v2._create_new_tf_function(body_graph),
+ name=scope)
+
+ _copy_handle_data(body_graph.outputs, outputs)
+ _maybe_set_lowering_attr(outputs[0].op)
+
+ # First var is loop counter.
+ if num_outputs == 1:
+ return outputs[1]
+ else:
+ return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
+
+
+@ops.RegisterGradient("While")
+def _WhileGrad(op, *grads): # pylint: disable=invalid-name
+ """The gradient of a While op produced by while_loop."""
+ body_graph = _get_body_graph(op)
+
+ # Replace None gradients with zeros. This is needed because `grads` could have
+ # None incoming gradients for the TensorLists. If we pass None's through, the
+ # custom gradient of TensorListPopBack will create an EmptyTensorList inside
+ # the FuncGraph which is undesirable.
+ # TODO(b/80444525): There might be an issue with treating no gradient as zero
+ # gradient in certain cases. Consider replacing None gradients with Zeros
+ # for accumulators only.
+ grads = [
+ g if g is not None else array_ops.zeros_like(output)
+ for g, output in zip(grads, op.outputs)
+ ]
+
+ body_grad_graph, args = _create_grad_func(
+ body_graph, grads,
+ _get_unique_name("%s_grad" % body_graph.name), op)
+
+ intermediate_tensors = _get_intermediates(body_grad_graph)
+
+ for intermediate_tensor in intermediate_tensors:
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=intermediate_tensor.dtype,
+ element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
+ with body_grad_graph.as_default():
+ tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
+ # Push the intermediate tensor to the tensor list.
+ appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
+ intermediate_tensor)
+ # Add this modified tensor list to the list of outputs.
+ body_grad_graph.outputs.append(appended_tensor_list)
+
+ def grad_cond(counter, max_iters, *unused_args):
+ return counter < max_iters
+
+ loop_vars = args + body_grad_graph.external_captures
+ cond_grad_graph = function.func_graph_from_py_func(
+ _get_unique_name("%s_grad_cond" % op.name),
+ grad_cond, loop_vars, {})
+
+ assert len(loop_vars) == len(body_grad_graph.inputs)
+ assert len(loop_vars) == len(body_grad_graph.outputs)
+ assert len(loop_vars) == len(cond_grad_graph.inputs)
+
+ outputs = gen_functional_ops._while(
+ loop_vars,
+ cond_v2._create_new_tf_function(cond_grad_graph),
+ cond_v2._create_new_tf_function(body_grad_graph),
+ name=_get_unique_name("%s_grad" % op.name))
+
+ _copy_handle_data(body_grad_graph.outputs, outputs)
+ _maybe_set_lowering_attr(outputs[0].op)
+
+ # outputs[0] is the loop counter.
+ # outputs[1] is the total number of loop iterations.
+ return outputs[2:2 + len(op.inputs)]
+
+
+# TODO(srbs): Pull this into common utils for cond_v2 and while_v2.
+def _get_body_graph(while_op):
+ """Returns `FuncGraph` for the while body.
+
+ Args:
+ while_op: The While Operation.
+
+ Returns:
+ `FuncGraph` for the while body.
+ """
+ extra_inputs = list(while_op.inputs)
+ input_shapes = [t.shape for t in extra_inputs]
+ func_name = while_op.get_attr("body").name
+ fdef = while_op.graph._get_function(func_name).definition
+ func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
+ func_graph._while = while_op
+ return func_graph
+
+
+def _create_grad_func(func_graph, grads, name, while_op):
+ """Builds and returns the gradient FuncGraph of `func_graph` and its args.
+
+ The returned grad_func_graph must be called with the returned
+ args + grad_func_graph.external_captures.
+
+ Args:
+ func_graph: FuncGraph for the forward body function.
+ grads: The incoming grads for `func_graph`'s outputs.
+ name: Name of the returned gradient function.
+ while_op: The forward While op.
+
+ Returns:
+ 2-tuple of (grad_func_graph, args).
+ """
+ assert len(func_graph.outputs) == len(grads)
+
+ loop_counter = constant_op.constant(0.)
+ # TODO(srbs): For nested while loops will need to lookup this value from
+ # the accumulator of the enclosing while loop. For now use as is assuming
+ # there is no nesting.
+ num_iters_t = while_op.outputs[0]
+
+ args = [loop_counter, num_iters_t] + grads
+
+ # Note: The returned function does not have `args` in the list of
+ # `external_captures`.
+ grad_func_graph = function.func_graph_from_py_func(
+ name,
+ lambda *args: _grad_fn(func_graph, args),
+ args, {},
+ func_graph=_WhileBodyGradFuncGraph(name, func_graph))
+
+ # Add the popped accumulators to the list of outputs.
+ for internal_capture in grad_func_graph.internal_captures:
+ grad_func_graph.outputs.append(
+ grad_func_graph.popped_tensor_lists[internal_capture])
+
+ return grad_func_graph, args
+
+
+def _grad_fn(func_graph, args):
+ """Computes the gradient of `func_graph` in the current graph.
+
+ This function builds the gradient graph of the corresponding forward-pass
+ `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.
+
+ Args:
+ func_graph: function.FuncGraph. The corresponding forward-pass function.
+ args: The input arguments. args[0] - Loop counter; args[1] - Total number
+ of iterations;
+ args[2:] - Incoming gradients for `func_graph.outputs`.
+
+ Returns:
+ The output gradient Tensors.
+ """
+ xs = func_graph.inputs
+ ys = func_graph.outputs
+ grad_ys = args[2:]
+
+ # Build the gradient graph. Note that this builds the gradient computation of
+ # func_graph in the current graph, which requires capturing tensors from
+ # func_graph. The captured func_graph tensors are resolved to external tensors
+ # in _resolve_grad_inputs.
+ # TODO(srbs): Mark GradientsHelper as public?
+ grad_outs = gradients_impl._GradientsHelper(
+ ys, xs, grad_ys=grad_ys, src_graph=func_graph)
+
+ assert all([g is not None for g in grad_outs])
+ counter = args[0]
+ total_iters = args[1]
+ return [counter + 1, total_iters] + grad_outs
+
+
+def _get_intermediates(func_graph):
+ """Returns all tensors in `func_graph` that should be accumulated."""
+ # We currently accumulate output tensors of most ops in the function and rely
+ # on the pruning pass to get rid of the unused accumulators at runtime.
+ # However, this can bloat the GraphDef and make debugging harder so we perform
+ # some optimizations.
+ #
+ # Optimization we currently perform:
+ # 1. We do not accumulate tensors which already have an accumulator
+ # in the loop body.
+ # 2. We do not accumulate outputs of Identity nodes. When building the
+ # FuncGraph, we add an Identity node for each output (see
+ # `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
+ # of all these nodes bloats the GraphDef quite a bit so we remove those.
+ # Since the gradient of an Identity node does not rely on its forward op's
+ # input this is safe to do.
+ #
+ # Other possible optimizations:
+ # 1. Only accumulate tensors that will be required by the backward pass.
+ # This will require running the gradient pass and hence would increase the
+ # graph building time for the forward pass.
+ # 2. Do not accumulate Const nodes created inside the loop body.
+ # 3. Do not accumulate inputs that are passed as-is, e.g. loop invariants.
+ # TODO(srbs): 2 and 3 may be hard optimizations for the runtime optimizer
+ # since it requires knowledge of the while loop semantics. If so, consider
+ # doing those here.
+ intermediates = []
+
+ for op in func_graph.get_operations():
+ if op.type == "Identity":
+ continue
+ for o in op.outputs:
+ if (o != func_graph.inputs[0] and # Loop counter.
+ _get_accumulator(o) is None): # No existing accumulator.
+ intermediates.append(o)
+ return intermediates
+
+
+def _get_accumulator(tensor):
+ r"""Returns the TensorList, if any, containing accumulated values of tensor.
+
+ We try to find a pattern of the form:
+
+ input_tl tensor
+ \ /
+ (TensorListPushBack)
+ |
+ output_tl
+
+ which satisfies the following conditions:
+
+ 1. input_tl must be in tensor.graph.inputs.
+ 2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
+ 3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).
+
+ output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
+ returned if such a pattern is found else None is returned.
+
+ Args:
+ tensor: The Tensor to be accumulated.
+
+ Returns:
+ A variant tensor in the same graph as `tensor` or None if no accumulator is
+ found.
+ """
+ assert isinstance(tensor.graph, function.FuncGraph)
+
+ def get_func_graph_output(t):
+ """Returns t or Identity(t) whichever exists in graph outputs else None."""
+ if t in tensor.graph.outputs:
+ return t
+ # tf.defun adds an Identity for each output, check whether that is the case.
+ identity_op = t.consumers()[0]
+ if (identity_op.type == "Identity" and
+ identity_op.outputs[0] in tensor.graph.outputs):
+ return identity_op.outputs[0]
+ return None
+
+ for consumer in tensor.consumers():
+ # Find the consumer that is a TensorListPushBack node whose TensorList input
+ # is in the list of function inputs.
+ if (consumer.type != "TensorListPushBack" or
+ consumer.inputs[0] not in tensor.graph.inputs):
+ continue
+
+ output = get_func_graph_output(consumer.outputs[0])
+ if output is None:
+ # The TensorList output of `consumer` is not in the list of function
+ # outputs.
+ continue
+
+ accum_input_idx = tensor.graph.inputs.index(consumer.inputs[0])
+ accum_output_idx = tensor.graph.outputs.index(output)
+ if accum_input_idx == accum_output_idx:
+ return output
+ return None
+
+
+# TODO(srbs): Add to common utils for cond_v2 and while_v2.
+def _get_unique_name(name):
+ """Returns a name unique in the default graph under `ops.init_scope()`.
+
+ Args:
+ name: String to uniquify.
+
+ Returns:
+ A string.
+ """
+ with ops.init_scope():
+ return ops.get_default_graph().unique_name(name)
+
+
+class _WhileBodyGradFuncGraph(function.FuncGraph):
+ """FuncGraph for the gradient function of the body of a While op.
+
+ Contains the logic for capturing the tensors from the body of the forward
+ While op which is as follows:
+ 1. Find the accumulator for that tensor.
+ 2. Capture the forward While op output tensor corresponding to the
+ accumulator in this FuncGraph.
+ 3. Pop a value from the captured placeholder and use it as the captured value
+ for the forward pass tensor.
+
+ This only allows capturing tensors in the forward graph. A ValueError is
+ raised if an attempt is made to capture a tensor not in the forward graph.
+ To manually capture a tensor that is not in the forward graph, call
+ `capture` with `whitelisted=True`.
+
+ Note: The `captures` dict does not contain the forward tensor since it is not
+ directly captured. It contains the accumulator corresponding to this forward
+ tensor.
+
+ Attributes:
+ popped_tensor_lists: Dict from the captured accumulator placeholder to the
+ TensorList obtained after popping the intermediate tensor from it. The
+ values of this dict need to be added to the list of outputs.
+ """
+
+ def __init__(self, name, forward_graph):
+ super(_WhileBodyGradFuncGraph, self).__init__(name)
+ self.popped_tensor_lists = {}
+ # FuncGraph for the body of the forward While op.
+ self._forward_graph = forward_graph
+ # Dict from forward intermediate tensor to the corresponding "popped" tensor
+ # in this graph.
+ self._indirect_captures = {}
+ # Dict from forward graph tensor to the While op output corresponding to its
+ # accumulator.
+ self._tensor_to_accumulator = {}
+
+ def capture(self, tensor, name=None, whitelisted=False):
+ """Selectively captures external tensors.
+
+ If `whitelisted` is False only allows capturing tensors in the
+ `_forward_graph`.
+
+ Args:
+ tensor: Tensor. May be from this FuncGraph or a different graph.
+ name: Optional name if a placeholder is created.
+ whitelisted: If False (default), only allows capturing tensors from the
+ forward graph.
+
+ Returns:
+ The placeholder in this graph for the tensor.
+
+ Raises:
+ ValueError: If attempting to capture an external tensor not in the forward
+ graph with `whitelisted` set to False.
+ """
+ if (not whitelisted and tensor.graph is not self and
+ tensor.graph != self._forward_graph):
+ raise ValueError("Attempting to capture tensor", str(tensor),
+ " which is not in the forward graph but in ",
+ _graph_name(tensor.graph), ".")
+ return super(_WhileBodyGradFuncGraph, self).capture(tensor, name)
+
+ def _capture_helper(self, tensor, name):
+ if tensor.graph is not self._forward_graph:
+ return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)
+
+ captured_tensor = self._indirect_captures.get(tensor)
+ if captured_tensor is not None:
+ # For GradientTape housekeeping.
+ assert self._tensor_to_accumulator[tensor] in self.captures
+ super(_WhileBodyGradFuncGraph, self)._capture_helper(
+ self._tensor_to_accumulator[tensor], name)
+ return captured_tensor
+
+ assert tensor not in self._tensor_to_accumulator
+
+ accumulator = None
+
+ # Find the TensorList that was used to accumulate the tensors of this
+ # intermediate tensor.
+ accumulator = _get_accumulator(tensor)
+ if accumulator is None:
+ raise ValueError("Reference to un-accumulated intermediate tensor: ",
+ tensor.name)
+ assert accumulator.graph == self._forward_graph
+ # Get the While op output corresponding to the accumulator.
+ accumulator = self._forward_graph._while.outputs[self._forward_graph.outputs
+ .index(accumulator)]
+
+ assert accumulator.graph == self._forward_graph.outer_graph
+ self._tensor_to_accumulator[tensor] = accumulator
+
+ # Capture the `accumulator`.
+ accumulator_ph = super(_WhileBodyGradFuncGraph, self)._capture_helper(
+ accumulator, name)
+ new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
+ accumulator_ph, element_dtype=tensor.dtype)
+ self._indirect_captures[tensor] = captured_tensor
+ self.popped_tensor_lists[accumulator_ph] = new_tensor_list
+ return captured_tensor
+
+
+def _copy_handle_data(src_tensors, tgt_tensors):
+ for src_t, tgt_t in zip(src_tensors, tgt_tensors):
+ function._copy_handle_data(src_t, tgt_t)
+
+
+# TODO(srbs): Move to common utils for cond_v2 and while_v2.
+def _maybe_set_lowering_attr(op):
+ """Sets the flag to enable lowering on the `While` op if necessary.
+
+ Lowering allows while_v2 to avoid some of the limitations of Functions,
+ allowing users to specify devices & colocation inside of while_v2
+ branches, and enabling non-strict evaluation & partial pruning of while_v2
+ branches. This brings while_v2 closer to feature parity with
+ tf.while_loop.
+
+ However, we do not lower `While` in the XLA context because it is easier
+ for XLA to apply its own optimizations when dealing with un-lowered
+ `While` operators than with low-level control flow primitives.
+
+ Args:
+ op: The While op.
+ """
+ if not control_flow_util.IsInXLAContext(op):
+ # pylint: disable=protected-access
+ op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
+ # pylint: enable=protected-access
+
+
+def _get_tensor_convertible_shape(shape):
+ assert isinstance(shape, tensor_shape.TensorShape)
+ if shape.is_fully_defined():
+ return shape
+ if not shape: # Unknown shape.
+ return -1
+ # Partially defined shape.
+ shape_list = shape.as_list()
+ shape_list = [s if s is not None else -1 for s in shape_list]
+ return ops.convert_to_tensor(shape_list)
+
+
+def _graph_name(graph):
+ if isinstance(graph, function.FuncGraph):
+ return graph.name
+ return "Base"
+
+
+# pylint: enable=protected-access
diff --git a/tensorflow/python/profiler/pprof_profiler_test.py b/tensorflow/python/profiler/pprof_profiler_test.py
index c2469f0..11a3487 100644
--- a/tensorflow/python/profiler/pprof_profiler_test.py
+++ b/tensorflow/python/profiler/pprof_profiler_test.py
@@ -141,7 +141,7 @@
run_metadata = config_pb2.RunMetadata()
num_iters = 5
- with self.test_session() as sess:
+ with self.cached_session() as sess:
i = constant_op.constant(0)
c = lambda i: math_ops.less(i, num_iters)
b = lambda i: math_ops.add(i, 1)
diff --git a/tensorflow/python/pywrap_tensorflow.py b/tensorflow/python/pywrap_tensorflow.py
index 5c0c578..f072427 100644
--- a/tensorflow/python/pywrap_tensorflow.py
+++ b/tensorflow/python/pywrap_tensorflow.py
@@ -68,7 +68,7 @@
sys.setdlopenflags(_default_dlopen_flags)
except ImportError:
msg = """%s\n\nFailed to load the native TensorFlow runtime.\n
-See https://www.tensorflow.org/install/install_sources#common_installation_problems\n
+See https://www.tensorflow.org/install/errors\n
for some common reasons and solutions. Include the entire stack trace
above this error message when asking for help.""" % traceback.format_exc()
raise ImportError(msg)
diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py
index dc990c2..670230e 100644
--- a/tensorflow/python/summary/writer/writer_test.py
+++ b/tensorflow/python/summary/writer/writer_test.py
@@ -286,7 +286,7 @@
def testAddingSummariesFromSessionRunCalls(self):
test_dir = self._CleanTestDir("global_step")
sw = self._FileWriter(test_dir)
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(1, dtype=dtypes.int32, shape=[])
l = constant_op.constant(2, dtype=dtypes.int64, shape=[])
# Test the summary can be passed serialized.
@@ -437,7 +437,7 @@
# Pass in test_session() as the session. It will be cached during this
# test method invocation so that any other use of test_session() with no
# graph should result in re-using the same underlying Session.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
kwargs["session"] = sess
return writer.FileWriter(*args, **kwargs)
return writer.FileWriter(*args, **kwargs)
diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py
index fcb3cea..a39c046 100644
--- a/tensorflow/python/tools/optimize_for_inference_test.py
+++ b/tensorflow/python/tools/optimize_for_inference_test.py
@@ -129,7 +129,7 @@
self.assertProtoEquals(expected_output, output)
def testFoldBatchNorms(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32)
@@ -161,7 +161,7 @@
optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
original_graph_def)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
@@ -224,7 +224,7 @@
self.assertNotEqual("FusedBatchNorm", node.op)
def testFuseResizePadAndConv(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs), shape=[1, 2, 3, 2], dtype=dtypes.float32)
@@ -242,7 +242,7 @@
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
original_graph_def, ["output"])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
@@ -255,7 +255,7 @@
self.assertNotEqual("ResizeBilinear", node.op)
def testFuseResizeAndConv(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs), shape=[1, 2, 3, 2], dtype=dtypes.float32)
@@ -271,7 +271,7 @@
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
original_graph_def, ["output"])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
@@ -284,7 +284,7 @@
def testFusePadAndConv(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs), shape=[1, 2, 3, 2], dtype=dtypes.float32)
@@ -300,7 +300,7 @@
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
original_graph_def, ["output"])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index 3508b98..cc0da26 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -34,7 +34,7 @@
See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
or this
- [intro](http://cs.stanford.edu/~ppasupat/a9online/uploads/proximal_notes.pdf).
+ [intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
"""
def __init__(self, learning_rate, initial_accumulator_value=0.1,
diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h
index 7f99d81..a4580d6 100644
--- a/tensorflow/stream_executor/device_description.h
+++ b/tensorflow/stream_executor/device_description.h
@@ -22,8 +22,7 @@
#include <map>
#include <memory>
-#include "tensorflow/stream_executor/platform/port.h"
-
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/launch_dim.h"
#include "tensorflow/stream_executor/platform/port.h"
@@ -359,9 +358,8 @@
bool ThreadDimOk(const DeviceDescription &device_description,
const ThreadDim &thread_dim);
-// [deprecated] Use MathUtil::CeilOfRatio directly instead.
-//
// Equivalent to ceil(double(element_count) / threads_per_block).
+ABSL_DEPRECATED("Use MathUtil::CeilOfRatio directly instead.")
uint64 DivideCeil(uint64 x, uint64 y);
// Calculate the number of threads/blocks required to process element_count
diff --git a/tensorflow/stream_executor/lib/array_slice.h b/tensorflow/stream_executor/lib/array_slice.h
index 8e3c4ca..5f4e586 100644
--- a/tensorflow/stream_executor/lib/array_slice.h
+++ b/tensorflow/stream_executor/lib/array_slice.h
@@ -16,13 +16,15 @@
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
-#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "absl/types/span.h"
namespace stream_executor {
namespace port {
-using tensorflow::gtl::ArraySlice;
-using tensorflow::gtl::MutableArraySlice;
+template <typename T>
+using ArraySlice = absl::Span<const T>;
+template <typename T>
+using MutableArraySlice = absl::Span<T>;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/inlined_vector.h b/tensorflow/stream_executor/lib/inlined_vector.h
index 40bdddb..0198947 100644
--- a/tensorflow/stream_executor/lib/inlined_vector.h
+++ b/tensorflow/stream_executor/lib/inlined_vector.h
@@ -16,12 +16,12 @@
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "absl/container/inlined_vector.h"
namespace stream_executor {
namespace port {
-using tensorflow::gtl::InlinedVector;
+using absl::InlinedVector;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/strcat.h b/tensorflow/stream_executor/lib/strcat.h
index c959e4d..3688d7b 100644
--- a/tensorflow/stream_executor/lib/strcat.h
+++ b/tensorflow/stream_executor/lib/strcat.h
@@ -18,13 +18,13 @@
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
-#include "tensorflow/core/lib/strings/strcat.h"
+#include "absl/strings/str_cat.h"
namespace stream_executor {
namespace port {
-using tensorflow::strings::StrCat;
-using tensorflow::strings::StrAppend;
+using absl::StrAppend;
+using absl::StrCat;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/stringpiece.h b/tensorflow/stream_executor/lib/stringpiece.h
index b80de5d..7624910 100644
--- a/tensorflow/stream_executor/lib/stringpiece.h
+++ b/tensorflow/stream_executor/lib/stringpiece.h
@@ -16,13 +16,12 @@
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/stream_executor/platform/port.h"
+#include "absl/strings/string_view.h"
namespace stream_executor {
namespace port {
-using tensorflow::StringPiece;
+using StringPiece = absl::string_view;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/plugin_registry.h b/tensorflow/stream_executor/plugin_registry.h
index 49628ec..3065b5c 100644
--- a/tensorflow/stream_executor/plugin_registry.h
+++ b/tensorflow/stream_executor/plugin_registry.h
@@ -18,6 +18,7 @@
#include <map>
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/fft.h"
@@ -97,6 +98,7 @@
// TODO(b/22689637): Deprecated/temporary. Will be deleted once all users are
// on MultiPlatformManager / PlatformId.
template <typename FactoryT>
+ ABSL_DEPRECATED("Use MultiPlatformManager / PlatformId instead.")
port::StatusOr<FactoryT> GetFactory(PlatformKind platform_kind,
PluginId plugin_id);
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index d04025b..4a8a270 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -22,6 +22,7 @@
#include <tuple>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/strcat.h"
@@ -81,8 +82,8 @@
port::Status Init();
port::Status Init(int device_ordinal, DeviceOptions device_options);
- // DEPRECATED: Do not use; use platform() instead.
// Returns the platform that this StreamExecutor is acting upon.
+ ABSL_DEPRECATED("Use platform() instead.")
PlatformKind platform_kind() const { return platform_kind_; }
// Returns a reference to the platform that created this executor.
@@ -255,15 +256,15 @@
// [deprecated] Blocks the caller while a data segment of the given size is
// copied from the host source to the device destination.
- //
- // Deprecation: prefer explicit H2D below, to avoid error-prone API usage.
+ ABSL_DEPRECATED(
+ "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
uint64 size) SE_MUST_USE_RESULT;
// [deprecated] Blocks the caller while a data segment of the given size is
// copied from the device source to the host destination.
- //
- // Deprecation: prefer explicit D2H below, to avoid error-prone API usage.
+ ABSL_DEPRECATED(
+ "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
uint64 size) SE_MUST_USE_RESULT;
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 7027e78..9e429a3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.estimator.BoostedTreesClassifier"
tf_class {
is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
is_instance: "<type \'object\'>"
member {
@@ -32,6 +33,10 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_predict_with_explanations"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index d8167ea..56af1d1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.estimator.BoostedTreesRegressor"
tf_class {
is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
is_instance: "<type \'object\'>"
member {
@@ -32,6 +33,10 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_predict_with_explanations"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 14ab885..fbc58e5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1593,6 +1593,10 @@
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "print"
+ argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
+ }
+ member_method {
name: "py_func"
argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
@@ -1801,6 +1805,10 @@
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
+ name: "searchsorted"
+ argspec: "args=[\'sorted_sequence\', \'values\', \'side\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'left\', \"<dtype: \'int32\'>\", \'None\'], "
+ }
+ member_method {
name: "segment_max"
argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index 018be7b..c81c156 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.strings"
tf_module {
member_method {
+ name: "format"
+ argspec: "args=[\'template\', \'inputs\', \'placeholder\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'{}\', \'3\', \'None\'], "
+ }
+ member_method {
name: "join"
argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 7027e78..9e429a3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.estimator.BoostedTreesClassifier"
tf_class {
is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
is_instance: "<type \'object\'>"
member {
@@ -32,6 +33,10 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_predict_with_explanations"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index d8167ea..56af1d1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.estimator.BoostedTreesRegressor"
tf_class {
is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
is_instance: "<type \'object\'>"
member {
@@ -32,6 +33,10 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_predict_with_explanations"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 323d2fc..7eca26b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -581,10 +581,6 @@
argspec: "args=[\'op_type\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "Print"
- argspec: "args=[\'input_\', \'data\', \'message\', \'first_n\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
name: "abs"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1541,6 +1537,10 @@
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "print"
+ argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
+ }
+ member_method {
name: "py_func"
argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
@@ -1725,6 +1725,10 @@
argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "searchsorted"
+ argspec: "args=[\'sorted_sequence\', \'values\', \'side\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'left\', \"<dtype: \'int32\'>\", \'None\'], "
+ }
+ member_method {
name: "segment_max"
argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index 018be7b..c81c156 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.strings"
tf_module {
member_method {
+ name: "format"
+ argspec: "args=[\'template\', \'inputs\', \'placeholder\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'{}\', \'3\', \'None\'], "
+ }
+ member_method {
name: "join"
argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
diff --git a/tensorflow/tools/ci_build/README.md b/tensorflow/tools/ci_build/README.md
index f2161b7..e2fd977 100644
--- a/tensorflow/tools/ci_build/README.md
+++ b/tensorflow/tools/ci_build/README.md
@@ -24,7 +24,7 @@
### Run TensorFlow CI Scripts Natively on your Machine
-1. Follow the instructions at https://www.tensorflow.org/install/install_sources,
+1. Follow the instructions at https://www.tensorflow.org/install/source,
but stop when you get to the section "Configure the installation". You do not
need to configure the installation to run the CI scripts.
diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
index 4b762bf..17198a6 100755
--- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh
+++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
@@ -64,7 +64,7 @@
fi
done
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
# PIP tests should have a "different" path. Different than the one we place
# virtualenv, because we are deleting and recreating it here.
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index cc09784c..49a9048 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -147,7 +147,7 @@
ANDROID_CMD="${CI_BUILD_DIR}/builds/android.sh"
ANDROID_FULL_CMD="${CI_BUILD_DIR}/builds/android_full.sh"
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
PARALLEL_GPU_TEST_CMD='//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute'
BENCHMARK_CMD="${CI_BUILD_DIR}/builds/benchmark.sh"
diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
index 03a2a07..cd7206b 100755
--- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
+++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
@@ -21,8 +21,8 @@
# Required environment variables:
# TF_GPU_COUNT = Number of GPUs available.
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
-TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU:-4}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
+TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU:-8}
# We want to allow running one of the following configs:
# - 4 tests per GPU on k80
# - 8 tests per GPU on p100
diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
index 28d5565..34847e6 100644
--- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
@@ -122,7 +122,7 @@
PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow_gpu-*.whl)
reinstall_tensorflow_pip ${PIP_NAME}
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
# Define no_tensorflow_py_deps=true so that every py_test has no deps anymore,
# which will result testing system installed tensorflow
diff --git a/tensorflow/tools/dist_test/README.md b/tensorflow/tools/dist_test/README.md
index 228d5ee..f8ed74a 100644
--- a/tensorflow/tools/dist_test/README.md
+++ b/tensorflow/tools/dist_test/README.md
@@ -23,7 +23,7 @@
./local_test.sh ${whl_file_url}
```
-For example, you can find these TensorFlow python package URLs from [here](https://www.tensorflow.org/install/install_linux#the_url_of_the_tensorflow_python_package) for Ubuntu.
+For example, you can find these TensorFlow python package URLs from [here](https://www.tensorflow.org/install/pip) for Ubuntu.
**2) Launch a remote k8s cluster on Google Kubernetes Engine (GKE) and run the
test suite on it**
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 31a3712..f86cb03 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -114,6 +114,7 @@
"//tensorflow/python/tools:tools_pip",
"//tensorflow/python/tools/api/generator:create_python_api",
"//tensorflow/python:test_ops",
+ "//tensorflow/python:while_v2",
"//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
]
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index d2e6f8d..d0531f8 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -491,11 +491,11 @@
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/ad72545325c087661feb3512efa54ebe5f888736.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/ad72545325c087661feb3512efa54ebe5f888736.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/db98902adc6431c9cc4ddec50fe174cfc9e626d6.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/db98902adc6431c9cc4ddec50fe174cfc9e626d6.tar.gz",
],
- sha256 = "66ed69443af00fbf9b912edbb6bc0fa796a12766b5e9ad504eb6b20f813dc163",
- strip_prefix = "llvm-ad72545325c087661feb3512efa54ebe5f888736",
+ sha256 = "8c02d312b3d417cf9bc7e58ff53c2528bf77a5d839ce4a23b95bd04b9e5da023",
+ strip_prefix = "llvm-db98902adc6431c9cc4ddec50fe174cfc9e626d6",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
diff --git a/third_party/cub.BUILD b/third_party/cub.BUILD
index 29159c9..a04347b 100644
--- a/third_party/cub.BUILD
+++ b/third_party/cub.BUILD
@@ -20,6 +20,7 @@
cc_library(
name = "cub",
hdrs = if_cuda([":cub_header_files"]),
+ include_prefix = "third_party",
deps = [
"@local_config_cuda//cuda:cuda_headers",
],
diff --git a/third_party/toolchains/BUILD b/third_party/toolchains/BUILD
index 4303751..7256a7d 100644
--- a/third_party/toolchains/BUILD
+++ b/third_party/toolchains/BUILD
@@ -32,6 +32,6 @@
remote_execution_properties = """
properties: {
name: "container-image"
- value:"docker://gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04@sha256:ae58329b961e7c17d89725bf8fd72dfbd5850f4f3313de58e0cafbf5b0343735"
+ value:"docker://gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04@sha256:06b585f42eed3b2030e9566b8f88f48d7472fa0f47e59765bc115376c8801bdf"
}""",
)