| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/data/dataset_utils.h" |
| |
| #include <functional> |
| #include <memory> |
| #include <queue> |
| #include <string> |
| #include <utility> |
| |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/strings/str_join.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/framework/attr_value.pb.h" |
| #include "tensorflow/core/framework/dataset.h" |
| #include "tensorflow/core/framework/function.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op_def_builder.h" |
| #include "tensorflow/core/framework/op_def_util.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/framework/tensor_util.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/graph/graph_def_builder.h" |
| #include "tensorflow/core/lib/core/blocking_counter.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/hash/hash.h" |
| #include "tensorflow/core/lib/strings/proto_serialization.h" |
| #include "tensorflow/core/platform/host_info.h" |
| #include "tensorflow/core/platform/regexp.h" |
| #include "tensorflow/core/util/determinism.h" |
| #include "tensorflow/core/util/work_sharder.h" |
| |
| namespace tensorflow { |
| namespace data { |
| namespace { |
| |
// Checkpoint key fragments. `kOutputSize` and `kOutput` are used by
// ReadBatch/WriteBatch; `kCode` and `kMessage` by ReadStatus/WriteStatus.
constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";
constexpr char kOutput[] = "output";
| |
| static mutex* get_dataset_experiment_registry_lock() { |
| static mutex dataset_experiment_registry_lock(LINKER_INITIALIZED); |
| return &dataset_experiment_registry_lock; |
| } |
| |
| static absl::flat_hash_map<string, int64_t>* get_dataset_experiments() { |
| static absl::flat_hash_map<string, int64_t>* experiments = |
| new absl::flat_hash_map<string, int64_t>; |
| return experiments; |
| } |
| |
// Names of tf.data graph rewrites and of rewrite config keys used in this
// file (e.g. "autotune", "slack_period"). Use "Opt" suffix so that they are
// not confused with the enums in Options proto.
constexpr char kMapAndBatchFusionOpt[] = "map_and_batch_fusion";
constexpr char kNoopEliminationOpt[] = "noop_elimination";
constexpr char kMapParallelizationOpt[] = "map_parallelization";
constexpr char kShuffleAndRepeatFusionOpt[] = "shuffle_and_repeat_fusion";
constexpr char kFilterFusionOpt[] = "filter_fusion";
constexpr char kMapAndFilterFusionOpt[] = "map_and_filter_fusion";
constexpr char kMapFusionOpt[] = "map_fusion";
constexpr char kParallelBatchOpt[] = "parallel_batch";
constexpr char kAutotuneBufferSizesOpt[] = "autotune_buffer_sizes";
constexpr char kDisablePrefetchLegacyAutotuneOpt[] =
    "disable_prefetch_legacy_autotune";
constexpr char kMakeSloppyOpt[] = "make_sloppy";
constexpr char kUseChooseFastestOpt[] = "use_choose_fastest";
constexpr char kBatchParallelizationOpt[] = "batch_parallelization";
constexpr char kEnableGradientDescentOpt[] = "enable_gradient_descent";
constexpr char kInjectPrefetchOpt[] = "inject_prefetch";
constexpr char kInjectPrefetchEligibleOpt[] = "inject_prefetch_eligible";
constexpr char kAutotuneOpt[] = "autotune";
constexpr char kSlackOpt[] = "slack";
constexpr char kSlackPeriodOpt[] = "slack_period";
constexpr char kMakeDeterministicOpt[] = "make_deterministic";
| |
| void DefaultOptimizationGraphRewrites( |
| const Options& options, absl::flat_hash_set<tstring>* optimization_enabled, |
| absl::flat_hash_set<tstring>* optimization_disabled, |
| absl::flat_hash_set<tstring>* optimization_default) { |
| const auto& optimization_options = options.optimization_options(); |
| if (optimization_options.optional_apply_default_optimizations_case() != |
| OptimizationOptions::kApplyDefaultOptimizations || |
| optimization_options.apply_default_optimizations()) { |
| if (optimization_options.optional_map_and_batch_fusion_case() != |
| OptimizationOptions::kMapAndBatchFusion) { |
| optimization_default->insert(kMapAndBatchFusionOpt); |
| } |
| if (optimization_options.optional_noop_elimination_case() != |
| OptimizationOptions::kNoopElimination) { |
| optimization_default->insert(kNoopEliminationOpt); |
| } |
| if (optimization_options.optional_map_parallelization_case() != |
| OptimizationOptions::kMapParallelization) { |
| optimization_default->insert(kMapParallelizationOpt); |
| } |
| if (optimization_options.optional_shuffle_and_repeat_fusion_case() != |
| OptimizationOptions::kShuffleAndRepeatFusion) { |
| optimization_default->insert(kShuffleAndRepeatFusionOpt); |
| } |
| if (optimization_options.optional_parallel_batch_case() != |
| OptimizationOptions::kParallelBatch) { |
| optimization_default->insert(kParallelBatchOpt); |
| } |
| } |
| if (OpDeterminismRequired()) { |
| optimization_enabled->insert(kMakeDeterministicOpt); |
| } |
| if (optimization_options.optional_filter_fusion_case() == |
| OptimizationOptions::kFilterFusion) { |
| if (optimization_options.filter_fusion()) { |
| optimization_enabled->insert(kFilterFusionOpt); |
| } else { |
| optimization_disabled->insert(kFilterFusionOpt); |
| } |
| } |
| if (optimization_options.optional_map_and_batch_fusion_case() == |
| OptimizationOptions::kMapAndBatchFusion) { |
| if (optimization_options.map_and_batch_fusion()) { |
| optimization_enabled->insert(kMapAndBatchFusionOpt); |
| } else { |
| optimization_disabled->insert(kMapAndBatchFusionOpt); |
| } |
| } |
| if (optimization_options.optional_map_and_filter_fusion_case() == |
| OptimizationOptions::kMapAndFilterFusion) { |
| if (optimization_options.map_and_filter_fusion()) { |
| optimization_enabled->insert(kMapAndFilterFusionOpt); |
| } else { |
| optimization_disabled->insert(kMapAndFilterFusionOpt); |
| } |
| } |
| if (optimization_options.optional_map_parallelization_case() == |
| OptimizationOptions::kMapParallelization) { |
| if (optimization_options.map_parallelization()) { |
| optimization_enabled->insert(kMapParallelizationOpt); |
| } else { |
| optimization_disabled->insert(kMapParallelizationOpt); |
| } |
| } |
| if (optimization_options.optional_map_fusion_case() == |
| OptimizationOptions::kMapFusion) { |
| if (optimization_options.map_fusion()) { |
| optimization_enabled->insert(kMapFusionOpt); |
| } else { |
| optimization_disabled->insert(kMapFusionOpt); |
| } |
| } |
| if (optimization_options.optional_noop_elimination_case() == |
| OptimizationOptions::kNoopElimination) { |
| if (optimization_options.noop_elimination()) { |
| optimization_enabled->insert(kNoopEliminationOpt); |
| } else { |
| optimization_disabled->insert(kNoopEliminationOpt); |
| } |
| } |
| if (optimization_options.optional_parallel_batch_case() == |
| OptimizationOptions::kParallelBatch) { |
| if (optimization_options.parallel_batch()) { |
| optimization_enabled->insert(kParallelBatchOpt); |
| } else { |
| optimization_disabled->insert(kParallelBatchOpt); |
| } |
| } |
| if (optimization_options.optional_shuffle_and_repeat_fusion_case() == |
| OptimizationOptions::kShuffleAndRepeatFusion) { |
| if (optimization_options.shuffle_and_repeat_fusion()) { |
| optimization_enabled->insert(kShuffleAndRepeatFusionOpt); |
| } else { |
| optimization_disabled->insert(kShuffleAndRepeatFusionOpt); |
| } |
| } |
| } |
| |
| // Returns whether an op has been allowlisted as stateless. Uses a heuristic to |
| // allowlist source dataset ops which have been marked stateful due to |
| // b/65524810. Also looks up the `op_def->name` in the global |
| // `AllowlistedStatefulOpRegistry`. |
| bool IsOpAllowlisted(const OpDef* op_def) { |
| return (op_def->output_arg_size() == 1 && |
| op_def->output_arg(0).type() == DT_VARIANT && |
| (absl::EndsWith(op_def->name(), "Dataset") || |
| absl::EndsWith(op_def->name(), "DatasetV2"))) || |
| AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name()); |
| } |
| |
| } // namespace |
| |
| std::pair<int64_t, int64_t> MaybeOverrideSeeds( |
| std::pair<int64_t, int64_t> seeds) { |
| if (seeds.first == 0 && seeds.second == 0) { |
| return {random::New64(), random::New64()}; |
| } |
| return seeds; |
| } |
| |
| Status VerifyTypeMatch(const DataType& expected, const DataType& received, |
| int index) { |
| if (expected != received) { |
| return errors::InvalidArgument("Data type mismatch at component ", index, |
| ": expected ", DataTypeString(expected), |
| " but got ", DataTypeString(received), "."); |
| } |
| return Status::OK(); |
| } |
| |
| Status VerifyTypesMatch(const DataTypeVector& expected, |
| const DataTypeVector& received) { |
| if (expected.size() != received.size()) { |
| return errors::InvalidArgument( |
| "Number of components does not match: expected ", expected.size(), |
| " types but got ", received.size(), "."); |
| } |
| for (size_t i = 0; i < expected.size(); ++i) { |
| TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i)); |
| } |
| return Status::OK(); |
| } |
| |
| Status VerifyTypesMatch(const DataTypeVector& expected, |
| const std::vector<Tensor>& received) { |
| if (expected.size() != received.size()) { |
| return errors::InvalidArgument( |
| "Number of components does not match: expected ", expected.size(), |
| " types but got ", received.size(), "."); |
| } |
| for (size_t i = 0; i < expected.size(); ++i) { |
| TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i)); |
| } |
| return Status::OK(); |
| } |
| |
| Status VerifyShapeCompatible(const PartialTensorShape& expected, |
| const PartialTensorShape& received, int index) { |
| if (!expected.IsCompatibleWith(received)) { |
| return errors::InvalidArgument("Incompatible shapes at component ", index, |
| ": expected ", expected.DebugString(), |
| " but got ", received.DebugString(), "."); |
| } |
| return Status::OK(); |
| } |
| |
| Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, |
| const std::vector<PartialTensorShape>& received) { |
| if (expected.size() != received.size()) { |
| return errors::InvalidArgument( |
| "Number of components does not match: expected ", expected.size(), |
| " shapes but got ", received.size(), "."); |
| } |
| for (size_t i = 0; i < expected.size(); ++i) { |
| TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, |
| const std::vector<Tensor>& received) { |
| if (expected.size() != received.size()) { |
| return errors::InvalidArgument( |
| "Number of components does not match: expected ", expected.size(), |
| " shapes but got ", received.size(), "."); |
| } |
| for (size_t i = 0; i < expected.size(); ++i) { |
| TF_RETURN_IF_ERROR( |
| VerifyShapeCompatible(expected[i], received[i].shape(), i)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AddToFunctionLibrary(FunctionLibraryDefinition* base, |
| const FunctionLibraryDefinition& to_add) { |
| for (const auto& fn : to_add.ListFunctionNames()) { |
| if (auto found = base->Find(fn)) { |
| if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) { |
| return errors::InvalidArgument("Cannot add function '", fn, |
| "' because a different function with " |
| "the same signature already exists."); |
| } |
| TF_RETURN_IF_ERROR(base->RemoveFunction(fn)); |
| } |
| } |
| return base->AddLibrary(to_add); |
| } |
| |
| Status AddToFunctionLibrary(FunctionLibraryDefinition* base, |
| const FunctionDefLibrary& to_add) { |
| for (const auto& fd : to_add.function()) { |
| if (auto found = base->Find(fd.signature().name())) { |
| if (!OpDefEqual(found->signature(), fd.signature())) { |
| return errors::InvalidArgument("Cannot add function '", |
| fd.signature().name(), |
| "' because a different function with " |
| "the same signature already exists."); |
| } |
| TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name())); |
| } |
| } |
| return base->AddLibrary(to_add); |
| } |
| |
| Status IsFunctionStateful(const FunctionLibraryDefinition& library, |
| const FunctionDef& function_def) { |
| if (!function_def.signature().is_stateful()) { |
| return Status::OK(); |
| } |
| |
| for (const NodeDef& node_def : function_def.node_def()) { |
| TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def)); |
| } |
| return Status::OK(); |
| } |
| |
| Status IsNodeStateful(const FunctionLibraryDefinition& library, |
| const NodeDef& node) { |
| const OpDef* op_def; |
| |
| // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore |
| // `LookUpOpDef` errors here. |
| if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() || |
| IsOpAllowlisted(op_def) || !op_def->is_stateful() || |
| op_def->name() == "Assert") { |
| return Status::OK(); |
| } |
| |
| if (op_def->name() == "If") { |
| const FunctionDef* then_func = |
| library.Find(node.attr().at("then_branch").func().name()); |
| const FunctionDef* else_func = |
| library.Find(node.attr().at("else_branch").func().name()); |
| if (then_func != nullptr) { |
| TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func)); |
| } |
| if (else_func != nullptr) { |
| TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func)); |
| } |
| return Status::OK(); |
| } |
| |
| if (op_def->name() == "While") { |
| const FunctionDef* cond_func = |
| library.Find(node.attr().at("cond").func().name()); |
| const FunctionDef* body_func = |
| library.Find(node.attr().at("body").func().name()); |
| if (cond_func != nullptr) { |
| TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func)); |
| } |
| if (body_func != nullptr) { |
| TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func)); |
| } |
| return Status::OK(); |
| } |
| |
| return errors::FailedPrecondition(op_def->name(), " is stateful."); |
| } |
| |
| std::function<void(std::function<void()>)> RunnerWithMaxParallelism( |
| std::function<void(std::function<void()>)> runner, int max_parallelism) { |
| return std::bind( |
| [max_parallelism]( |
| // Note: `runner` is a const reference to avoid copying it. |
| const std::function<void(std::function<void()>)>& runner, |
| std::function<void()> fn) { |
| std::function<void()> scoped_fn = std::bind( |
| [max_parallelism](const std::function<void()>& fn) { |
| ScopedPerThreadMaxParallelism scope(max_parallelism); |
| fn(); |
| }, |
| std::move(fn)); |
| runner(std::move(scoped_fn)); |
| }, |
| std::move(runner), std::placeholders::_1); |
| } |
| |
| Status DeterminismPolicy::FromString(const std::string& s, |
| DeterminismPolicy* out) { |
| DeterminismPolicy::Type type; |
| if (s == DeterminismPolicy::kDeterministic) { |
| type = DeterminismPolicy::Type::kDeterministic; |
| } else if (s == DeterminismPolicy::kNondeterministic) { |
| type = DeterminismPolicy::Type::kNondeterministic; |
| } else if (s == DeterminismPolicy::kDefault) { |
| type = DeterminismPolicy::Type::kDefault; |
| } else { |
| return errors::InvalidArgument("Unrecognized determinism policy: ", s); |
| } |
| *out = DeterminismPolicy(type); |
| return Status::OK(); |
| } |
| |
| DeterminismPolicy::DeterminismPolicy(bool is_deterministic) { |
| if (is_deterministic) { |
| determinism_ = DeterminismPolicy::Type::kDeterministic; |
| } else { |
| determinism_ = DeterminismPolicy::Type::kNondeterministic; |
| } |
| } |
| |
| std::string DeterminismPolicy::String() const { |
| switch (determinism_) { |
| case DeterminismPolicy::Type::kDeterministic: |
| return DeterminismPolicy::kDeterministic; |
| case DeterminismPolicy::Type::kNondeterministic: |
| return DeterminismPolicy::kNondeterministic; |
| case DeterminismPolicy::Type::kDefault: |
| return DeterminismPolicy::kDefault; |
| default: |
| LOG(ERROR) << "Unrecognized determinism value"; |
| return "Unrecognized"; |
| } |
| } |
| |
| bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match) { |
| if (!absl::StartsWith(op_to_match, op_prefix)) { |
| return false; |
| } |
| if (op_to_match.length() == op_prefix.length()) { |
| return true; |
| } |
| size_t index = op_to_match.length() - 1; |
| while (isdigit(op_to_match[index])) { |
| index--; |
| } |
| return (op_to_match[index] == 'V') && (op_prefix.length() == index); |
| } |
| |
| absl::flat_hash_set<string> GetExperiments() { |
| return GetExperiments(port::JobName(), |
| [](const tstring& str) { return Hash64(str); }); |
| } |
| |
| absl::flat_hash_set<string> GetExperiments( |
| const string& job_name, std::function<uint64(const string&)> hash_func) { |
| absl::flat_hash_set<string> experiments; |
| |
| if (job_name.empty()) { |
| return experiments; |
| } |
| |
| // Parse the opt-in and opt-out settings. |
| const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN"); |
| const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT"); |
| string opt_ins_raw; |
| if (opt_ins_raw_cs != nullptr) { |
| opt_ins_raw = string(opt_ins_raw_cs); |
| } |
| string opt_outs_raw; |
| if (opt_outs_raw_cs != nullptr) { |
| opt_outs_raw = string(opt_outs_raw_cs); |
| } |
| |
| // Identify opted out experiments. |
| absl::flat_hash_map<string, int64_t> live_experiments = |
| DatasetExperimentRegistry::Experiments(); |
| absl::flat_hash_set<string> opt_outs; |
| if (opt_outs_raw == "all") { |
| for (const auto& pair : live_experiments) { |
| opt_outs.insert(pair.first); |
| } |
| } else { |
| for (const auto& experiment : |
| str_util::Split(opt_outs_raw, ',', str_util::SkipEmpty())) { |
| opt_outs.insert(experiment); |
| } |
| } |
| |
| // Include opted in experiments unless they are opted out. |
| if (opt_ins_raw == "all") { |
| for (const auto& pair : live_experiments) { |
| auto experiment = pair.first; |
| if (!opt_outs.contains(experiment)) { |
| experiments.insert(experiment); |
| } |
| } |
| } else { |
| for (const auto& experiment : |
| str_util::Split(opt_ins_raw, ',', str_util::SkipEmpty())) { |
| if (!opt_outs.contains(experiment)) { |
| experiments.insert(experiment); |
| } |
| } |
| } |
| |
| // Stochastically include live experiments unless they are opted out. |
| for (const auto& pair : live_experiments) { |
| auto& experiment = pair.first; |
| if ((hash_func(strings::StrCat(job_name, experiment)) % 100 < |
| pair.second) && |
| !opt_outs.contains(experiment)) { |
| experiments.insert(experiment); |
| } |
| } |
| |
| return experiments; |
| } |
| |
| void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments) { |
| if (!experiments.empty()) { |
| constexpr float TEN_MINUTES = 60.0 * 10.0; |
| LOG_EVERY_N_SEC(INFO, TEN_MINUTES) |
| << "The input pipeline is subject to the following tf.data experiments:" |
| << " " << absl::StrJoin(experiments, ", ") << ". " |
| << "See `go/tf-data-experiments` for more details."; |
| } |
| for (auto& experiment : experiments) { |
| metrics::RecordTFDataExperiment(experiment); |
| } |
| } |
| |
| void GetOptimizations(const Options& options, |
| absl::flat_hash_set<tstring>* optimizations_enabled, |
| absl::flat_hash_set<tstring>* optimizations_disabled, |
| absl::flat_hash_set<tstring>* optimizations_default) { |
| DefaultOptimizationGraphRewrites(options, optimizations_enabled, |
| optimizations_disabled, |
| optimizations_default); |
| if (!OpDeterminismRequired() && |
| options.optional_deterministic_case() == Options::kDeterministic && |
| !options.deterministic()) { |
| optimizations_enabled->insert(kMakeSloppyOpt); |
| } |
| if (options.optional_slack_case() == Options::kSlack) { |
| if (options.slack()) { |
| optimizations_enabled->insert(kSlackOpt); |
| } else { |
| optimizations_disabled->insert(kSlackOpt); |
| } |
| } |
| } |
| |
| Tensor MaybeCopySubSlice(const Tensor& tensor, int64 index) { |
| Tensor slice = tensor.SubSlice(index); |
| if (slice.IsAligned()) { |
| return slice; |
| } else { |
| return tensorflow::tensor::DeepCopy(slice); |
| } |
| } |
| |
| void StripDevicePlacement(FunctionDefLibrary* library) { |
| for (auto& function : (*library->mutable_function())) { |
| for (auto& node : (*function.mutable_node_def())) { |
| if (!node.device().empty()) { |
| *node.mutable_device() = ""; |
| } |
| } |
| } |
| } |
| |
| Status CopyPartialBatch(int64_t num_elements, const Tensor& value, |
| Tensor* output) { |
| switch (value.dtype()) { |
| #define HANDLE_TYPE(type) \ |
| case DataTypeToEnum<type>::value: { \ |
| auto output_t = output->flat_outer_dims<type>(); \ |
| auto value_t = value.flat_outer_dims<type>(); \ |
| for (size_t i = 0; i < num_elements; i++) { \ |
| output_t.template chip<0>(i) = value_t.template chip<0>(i); \ |
| } \ |
| return Status::OK(); \ |
| } |
| TF_CALL_DATASET_TYPES(HANDLE_TYPE); |
| #undef HANDLE_TYPE |
| default: |
| return errors::InvalidArgument("Unsupported data type: ", |
| DataTypeString(value.dtype())); |
| } |
| return Status::OK(); |
| } |
| |
| Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader, |
| int64_t batch_size, const string& iterator_prefix, |
| const string& batch_prefix, std::vector<Tensor>* batch) { |
| int64_t output_size; |
| TF_RETURN_IF_ERROR(reader->ReadScalar( |
| FullName(iterator_prefix, |
| strings::StrCat(batch_prefix, "_", kOutputSize)), |
| &output_size)); |
| batch->reserve(output_size); |
| for (int i = 0; i < output_size; i++) { |
| Tensor t; |
| TF_RETURN_IF_ERROR( |
| reader->ReadTensor(ctx->flr(), FullName(iterator_prefix, batch_prefix), |
| strings::StrCat(kOutput, "_", i), &t)); |
| // If the batch was not full, we may have stored only the relevant slice. |
| // Since tensors in `BatchResult.output` are expected to have the leading |
| // dimension of size batch_size, we build a larger tensor and copy the slice |
| // read from the checkpoint into it. |
| if (t.dim_size(0) < batch_size) { |
| TensorShape component_shape(t.shape()); |
| component_shape.set_dim(0, batch_size); |
| AllocatorAttributes attr; |
| attr.set_gpu_compatible(true); |
| Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape); |
| TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t)); |
| batch->emplace_back(std::move(new_t)); |
| } else { |
| batch->emplace_back(std::move(t)); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status WriteBatch(int64_t batch_size, int64_t num_elements, |
| const string& iterator_prefix, const string& batch_prefix, |
| IteratorStateWriter* writer, std::vector<Tensor>* batch) { |
| TF_RETURN_IF_ERROR(writer->WriteScalar( |
| FullName(iterator_prefix, |
| strings::StrCat(batch_prefix, "_", kOutputSize)), |
| batch->size())); |
| for (int i = 0; i < batch->size(); i++) { |
| // If the batch is not full, we only store the first `num_elements` values. |
| // The rest of the batch tensor is *uninitialized* and accessing that will |
| // raise msan errors. |
| if (num_elements < batch_size) { |
| TF_RETURN_IF_ERROR( |
| writer->WriteTensor(FullName(iterator_prefix, batch_prefix), |
| strings::StrCat(kOutput, "_", i), |
| (*batch)[i].Slice(0, num_elements))); |
| } else { |
| TF_RETURN_IF_ERROR( |
| writer->WriteTensor(FullName(iterator_prefix, batch_prefix), |
| strings::StrCat(kOutput, "_", i), (*batch)[i])); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status ReadStatus(const string& iterator_prefix, const string& prefix, |
| IteratorStateReader* reader, Status* status) { |
| int64_t code_int; |
| TF_RETURN_IF_ERROR(reader->ReadScalar( |
| FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)), |
| &code_int)); |
| error::Code code = static_cast<error::Code>(code_int); |
| |
| if (code != error::Code::OK) { |
| tstring error_message; |
| TF_RETURN_IF_ERROR(reader->ReadScalar( |
| FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)), |
| &error_message)); |
| *status = Status(code, error_message); |
| } else { |
| *status = Status::OK(); |
| } |
| return Status::OK(); |
| } |
| |
| Status WriteStatus(const string& iterator_prefix, const string& prefix, |
| const Status& status, IteratorStateWriter* writer) { |
| TF_RETURN_IF_ERROR(writer->WriteScalar( |
| FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)), |
| static_cast<int64_t>(status.code()))); |
| if (!status.ok()) { |
| TF_RETURN_IF_ERROR(writer->WriteScalar( |
| FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)), |
| status.error_message())); |
| } |
| return Status::OK(); |
| } |
| |
| Status ProcessBatch(int64_t batch_size, int64_t num_elements, |
| bool drop_remainder, const Status& status, |
| IteratorContext* ctx, std::vector<Tensor>* output, |
| bool* end_of_sequence, std::vector<Tensor>* batch) { |
| if (num_elements == 0) { |
| if (status.ok() || errors::IsOutOfRange(status)) { |
| *end_of_sequence = true; |
| return Status::OK(); |
| } else { |
| *end_of_sequence = false; |
| return status; |
| } |
| } |
| if (!status.ok() && !errors::IsOutOfRange(status)) { |
| *end_of_sequence = false; |
| return status; |
| } |
| if (num_elements < batch_size) { |
| if (drop_remainder) { |
| *end_of_sequence = true; |
| return Status::OK(); |
| } |
| for (size_t i = 0; i < batch->size(); ++i) { |
| TensorShape component_shape((*batch)[i].shape()); |
| component_shape.set_dim(0, num_elements); |
| AllocatorAttributes attr; |
| attr.set_gpu_compatible(true); |
| output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(), |
| component_shape); |
| if (!output->back().IsInitialized()) { |
| return errors::ResourceExhausted( |
| "Failed to allocate memory for the batch of component ", i); |
| } |
| TF_RETURN_IF_ERROR( |
| CopyPartialBatch(num_elements, (*batch)[i], &output->back())); |
| } |
| } else { |
| *output = std::move(*batch); |
| } |
| *end_of_sequence = false; |
| return Status::OK(); |
| } |
| |
| Status CopyBatch(CopyBatchParams params, |
| const std::vector<std::vector<Tensor>>& batch_elements, |
| bool parallel_copy, |
| std::function<Status()> allocation_callback, |
| std::vector<Tensor>* out_tensors) { |
| const size_t num_tuple_components = batch_elements.at(0).size(); |
| out_tensors->reserve(num_tuple_components); |
| const int64_t num_batch_elements = batch_elements.size(); |
| for (size_t component_index = 0; component_index < num_tuple_components; |
| ++component_index) { |
| const Tensor& first_element = batch_elements.at(0)[component_index]; |
| TensorShape first_element_shape(first_element.shape()); |
| TensorShape batch_component_shape({num_batch_elements}); |
| batch_component_shape.AppendShape(first_element_shape); |
| out_tensors->emplace_back(params.allocator, first_element.dtype(), |
| batch_component_shape); |
| if (!out_tensors->back().IsInitialized()) { |
| return errors::ResourceExhausted( |
| "Failed to allocate memory for the batch of component ", |
| component_index); |
| } |
| } |
| if (allocation_callback) { |
| TF_RETURN_IF_ERROR(allocation_callback()); |
| } |
| for (size_t component_index = 0; component_index < num_tuple_components; |
| ++component_index) { |
| Tensor& batch_component = out_tensors->at(component_index); |
| const Tensor& first_element = batch_elements.at(0)[component_index]; |
| TensorShape first_element_shape(first_element.shape()); |
| // Build the output tuple component by copying one slice from each input |
| // element in the batch. |
| auto copy_element_fn = [component_index, &batch_elements, &batch_component, |
| &first_element_shape](int index) { |
| if (batch_elements.at(index)[component_index].shape() != |
| first_element_shape) { |
| return errors::InvalidArgument( |
| "Cannot batch tensors with different shapes in component ", |
| component_index, ". First element had shape ", |
| first_element_shape.DebugString(), " and element ", index, |
| " had shape ", |
| batch_elements.at(index)[component_index].shape().DebugString(), |
| "."); |
| } |
| return batch_util::CopyElementToSlice( |
| std::move(batch_elements.at(index)[component_index]), |
| &batch_component, index); |
| }; |
| if (parallel_copy && first_element.AllocatedBytes() > (1 << 15)) { |
| Status status; |
| mutex status_mu; |
| BlockingCounter counter(num_batch_elements); |
| const auto num_threads = params.runner_threadpool_size; |
| const auto slice_size = num_batch_elements / num_threads; |
| int64_t offset = 0; |
| for (size_t i = 0; i < num_threads; ++i) { |
| int64_t length = slice_size; |
| // When the number of threads does not divide the number of elements |
| // evenly, the size of some slices is incremented to guarantee their |
| // sizes add up to the total number of elements. |
| if (i < num_batch_elements % num_threads) ++length; |
| (*params.runner)([offset, length, &status, &status_mu, &counter, |
| ©_element_fn]() { |
| for (size_t j = offset; j < offset + length; ++j) { |
| { |
| Status s = copy_element_fn(j); |
| mutex_lock l(status_mu); |
| status.Update(s); |
| } |
| counter.DecrementCount(); |
| } |
| }); |
| offset += length; |
| } |
| counter.Wait(); |
| TF_RETURN_IF_ERROR(status); |
| } else { |
| for (size_t i = 0; i < num_batch_elements; ++i) { |
| TF_RETURN_IF_ERROR(copy_element_fn(i)); |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| |
| absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options) { |
| absl::flat_hash_set<tstring> configs; |
| const auto& autotune_options = options.autotune_options(); |
| std::vector<tstring> autotune_only_optimizations = { |
| kAutotuneBufferSizesOpt, |
| kBatchParallelizationOpt, |
| kDisablePrefetchLegacyAutotuneOpt, |
| kEnableGradientDescentOpt, |
| kMapParallelizationOpt, |
| kInjectPrefetchOpt, |
| kInjectPrefetchEligibleOpt}; |
| |
| if (autotune_options.optional_enabled_case() == AutotuneOptions::kEnabled && |
| !autotune_options.enabled()) { |
| for (const auto& optimization : autotune_only_optimizations) { |
| configs.insert( |
| absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":false")); |
| } |
| } else { |
| for (const auto& optimization : autotune_only_optimizations) { |
| configs.insert( |
| absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":true")); |
| } |
| } |
| if (options.slack()) { |
| int num_devices = 1; |
| if (options.distribute_options().optional_num_devices_case() == |
| DistributeOptions::kNumDevices) { |
| num_devices = options.distribute_options().num_devices(); |
| } |
| configs.insert( |
| absl::StrCat(kSlackOpt, ":", kSlackPeriodOpt, ":", num_devices)); |
| } |
| return configs; |
| } |
| |
| bool ShouldConfigureMaxIntraOpParallelism(const Options& options) { |
| return options.threading_options().optional_max_intra_op_parallelism_case() == |
| ThreadingOptions::kMaxIntraOpParallelism; |
| } |
| |
| bool ShouldUsePrivateThreadPool(const Options& options) { |
| return options.threading_options().optional_private_threadpool_size_case() == |
| ThreadingOptions::kPrivateThreadpoolSize; |
| } |
| |
| bool ShouldUseAutotuning(const Options& options) { |
| return options.autotune_options().optional_enabled_case() != |
| AutotuneOptions::kEnabled || |
| options.autotune_options().enabled(); |
| } |
| |
| bool ShouldApplyOptimizations( |
| const Options& options, |
| const absl::flat_hash_set<tstring>& optimizations_enabled, |
| const absl::flat_hash_set<tstring>& optimizations_default) { |
| return (options.optimization_options() |
| .optional_apply_default_optimizations_case() != |
| OptimizationOptions::kApplyDefaultOptimizations || |
| options.optimization_options().apply_default_optimizations() || |
| !optimizations_enabled.empty() || !optimizations_default.empty()); |
| } |
| |
| // static |
| void DatasetExperimentRegistry::Register(const string& experiment, |
| int64_t rollout_pct) { |
| mutex_lock l(*get_dataset_experiment_registry_lock()); |
| get_dataset_experiments()->insert(std::make_pair(experiment, rollout_pct)); |
| } |
| |
| // static |
| absl::flat_hash_map<string, int64_t> DatasetExperimentRegistry::Experiments() { |
| mutex_lock l(*get_dataset_experiment_registry_lock()); |
| return *get_dataset_experiments(); |
| } |
| |
namespace {

// Built-in tf.data experiments. The second argument is presumably the rollout
// percentage handed to DatasetExperimentRegistry::Register, which
// GetExperiments compares against a per-job hash bucket. Verify against the
// REGISTER_DATASET_EXPERIMENT macro definition.
REGISTER_DATASET_EXPERIMENT("max_parallelism", 100);
REGISTER_DATASET_EXPERIMENT("max_parallelism_v2", 50);
REGISTER_DATASET_EXPERIMENT("min_outer_interleave_parallelism", 0);
REGISTER_DATASET_EXPERIMENT("inject_prefetch", 5);
}  // namespace
| } // namespace data |
| } // namespace tensorflow |