| /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // See docs in ../ops/sdca_ops.cc. |
| |
| #define EIGEN_USE_THREADS |
| |
| #include <stdint.h> |
| #include <atomic> |
| #include <limits> |
| #include <memory> |
| #include <new> |
| #include <string> |
| #include <vector> |
| |
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
| #include "tensorflow/core/framework/device_base.h" |
| #include "tensorflow/core/framework/kernel_def_builder.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/op_def_builder.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/framework/tensor_types.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/kernels/hinge-loss.h" |
| #include "tensorflow/core/kernels/logistic-loss.h" |
| #include "tensorflow/core/kernels/loss.h" |
| #include "tensorflow/core/kernels/poisson-loss.h" |
| #include "tensorflow/core/kernels/sdca_internal.h" |
| #include "tensorflow/core/kernels/smooth-hinge-loss.h" |
| #include "tensorflow/core/kernels/squared-loss.h" |
| #include "tensorflow/core/lib/core/coding.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/core/stringpiece.h" |
| #include "tensorflow/core/lib/gtl/inlined_vector.h" |
| #include "tensorflow/core/lib/strings/stringprintf.h" |
| #include "tensorflow/core/platform/fingerprint.h" |
| #include "tensorflow/core/platform/macros.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/util/work_sharder.h" |
| |
| namespace tensorflow { |
| |
| namespace { |
| |
| using sdca::Example; |
| using sdca::Examples; |
| using sdca::ExampleStatistics; |
| using sdca::ModelWeights; |
| using sdca::Regularizations; |
| |
// Parsed, validated copy of the kernel attributes shared by the SDCA training
// kernels. Constructed once per kernel instance; on any invalid attribute the
// OP_REQUIRES* macros record the error on `context` and return early from the
// constructor, leaving the remaining members at their defaults.
struct ComputeOptions {
  explicit ComputeOptions(OpKernelConstruction* const context) {
    string loss_type;
    OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type));
    // Select the dual-loss updater implementation from the attribute value.
    if (loss_type == "logistic_loss") {
      loss_updater.reset(new LogisticLossUpdater);
    } else if (loss_type == "squared_loss") {
      loss_updater.reset(new SquaredLossUpdater);
    } else if (loss_type == "hinge_loss") {
      loss_updater.reset(new HingeLossUpdater);
    } else if (loss_type == "smooth_hinge_loss") {
      loss_updater.reset(new SmoothHingeLossUpdater);
    } else if (loss_type == "poisson_loss") {
      loss_updater.reset(new PoissonLossUpdater);
    } else {
      OP_REQUIRES(
          context, false,
          errors::InvalidArgument("Unsupported loss type: ", loss_type));
    }
    // The attribute was historically misspelled "adaptative"; try that name
    // first and fall back to the corrected "adaptive" spelling so both old
    // and new graph defs load.
    auto s = context->GetAttr("adaptative", &adaptive);
    if (!s.ok()) {
      s = context->GetAttr("adaptive", &adaptive);
    }
    OP_REQUIRES_OK(context, s);
    OP_REQUIRES_OK(
        context, context->GetAttr("num_sparse_features", &num_sparse_features));
    OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values",
                                             &num_sparse_features_with_values));
    OP_REQUIRES_OK(context,
                   context->GetAttr("num_dense_features", &num_dense_features));
    OP_REQUIRES(
        context, num_sparse_features + num_dense_features > 0,
        errors::InvalidArgument("Requires at least one feature to train."));

    // Guard against int overflow of the summed feature-group count: do the
    // addition in 64 bits and reject totals that do not fit in an int.
    OP_REQUIRES(context,
                static_cast<int64>(num_sparse_features) +
                        static_cast<int64>(num_dense_features) <=
                    std::numeric_limits<int>::max(),
                errors::InvalidArgument(
                    strings::Printf("Too many feature groups: %lld > %d",
                                    static_cast<int64>(num_sparse_features) +
                                        static_cast<int64>(num_dense_features),
                                    std::numeric_limits<int>::max())));
    OP_REQUIRES_OK(
        context, context->GetAttr("num_loss_partitions", &num_loss_partitions));
    OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
                                             &num_inner_iterations));
    OP_REQUIRES_OK(context, regularizations.Initialize(context));
  }

  // Dual-loss updater chosen from the "loss_type" attribute; never null after
  // a successful construction.
  std::unique_ptr<DualLossUpdater> loss_updater;
  int num_sparse_features = 0;
  int num_sparse_features_with_values = 0;
  int num_dense_features = 0;
  int num_inner_iterations = 0;
  int num_loss_partitions = 0;
  // Whether to use adaptive example sampling in DoCompute().
  bool adaptive = true;
  Regularizations regularizations;  // L1/L2 regularization parameters.
};
| |
// TODO(shengx): The helper classes/methods are changed to support multiclass
// SDCA, which lead to changes within this function. Need to revisit the
// convergence once the multiclass SDCA is in.
//
// Runs one round of SDCA training: reads model weights and examples from the
// op inputs, performs one dual-coordinate update per example (sharded over
// the CPU worker threads), and emits the updated per-example state as the
// "out_example_state_data" output.
void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
  ModelWeights model_weights;
  OP_REQUIRES_OK(context, model_weights.Initialize(context));

  Examples examples;
  OP_REQUIRES_OK(
      context,
      examples.Initialize(context, model_weights, options.num_sparse_features,
                          options.num_sparse_features_with_values,
                          options.num_dense_features));

  // Per-example state has a fixed width of 4: columns hold
  // {dual, primal loss, dual loss, example weight} (see updates below).
  const Tensor* example_state_data_t;
  OP_REQUIRES_OK(context,
                 context->input("example_state_data", &example_state_data_t));
  TensorShape expected_example_state_shape({examples.num_examples(), 4});
  OP_REQUIRES(context,
              example_state_data_t->shape() == expected_example_state_shape,
              errors::InvalidArgument(
                  "Expected shape ", expected_example_state_shape.DebugString(),
                  " for example_state_data, got ",
                  example_state_data_t->shape().DebugString()));

  // Shallow copy: shares the input buffer so the state can be updated in
  // place and forwarded as the output without an extra allocation.
  Tensor mutable_example_state_data_t(*example_state_data_t);
  auto example_state_data = mutable_example_state_data_t.matrix<float>();
  OP_REQUIRES_OK(context, context->set_output("out_example_state_data",
                                              mutable_example_state_data_t));

  // Pick the example visitation order: adaptive importance sampling, or a
  // plain random shuffle.
  if (options.adaptive) {
    OP_REQUIRES_OK(context,
                   examples.SampleAdaptiveProbabilities(
                       options.num_loss_partitions, options.regularizations,
                       model_weights, example_state_data, options.loss_updater,
                       /*num_weight_vectors =*/1));
  } else {
    examples.RandomShuffle();
  }
  // Records the first error reported by any worker shard; guarded by a mutex
  // because multiple shards may fail concurrently.
  struct {
    mutex mu;
    Status value GUARDED_BY(mu);
  } train_step_status;
  // Global cursor into the sampled order: each worker claims the next slot
  // atomically, so the sampled sequence is consumed exactly once overall.
  std::atomic<std::int64_t> atomic_index(-1);
  auto train_step = [&](const int64 begin, const int64 end) {
    // The static_cast here is safe since begin and end can be at most
    // num_examples which is an int.
    for (int id = static_cast<int>(begin); id < end; ++id) {
      const int64 example_index = examples.sampled_index(++atomic_index);
      const Example& example = examples.example(example_index);
      const float dual = example_state_data(example_index, 0);
      const float example_weight = example.example_weight();
      float example_label = example.example_label();
      const Status conversion_status =
          options.loss_updater->ConvertLabel(&example_label);
      if (!conversion_status.ok()) {
        mutex_lock l(train_step_status.mu);
        train_step_status.value = conversion_status;
        // Return from this worker thread - the calling thread is
        // responsible for checking context status and returning on error.
        return;
      }

      // Compute wx, example norm weighted by regularization, dual loss,
      // primal loss.
      // For binary SDCA, num_weight_vectors should be one.
      const ExampleStatistics example_statistics =
          example.ComputeWxAndWeightedExampleNorm(
              options.num_loss_partitions, model_weights,
              options.regularizations, 1 /* num_weight_vectors */);

      const double new_dual = options.loss_updater->ComputeUpdatedDual(
          options.num_loss_partitions, example_label, example_weight, dual,
          example_statistics.wx[0], example_statistics.normalized_squared_norm);

      // Compute new weights.
      const double normalized_bounded_dual_delta =
          (new_dual - dual) * example_weight /
          options.regularizations.symmetric_l2();
      model_weights.UpdateDeltaWeights(
          context->eigen_cpu_device(), example,
          std::vector<double>{normalized_bounded_dual_delta});

      // Update example data.
      example_state_data(example_index, 0) = new_dual;
      example_state_data(example_index, 1) =
          options.loss_updater->ComputePrimalLoss(
              example_statistics.prev_wx[0], example_label, example_weight);
      example_state_data(example_index, 2) =
          options.loss_updater->ComputeDualLoss(dual, example_label,
                                                example_weight);
      example_state_data(example_index, 3) = example_weight;
    }
  };
  // TODO(sibyl-Aix6ihai): Tune this properly based on sparsity of the data,
  // number of cpus, and cost per example.
  const int64 kCostPerUnit = examples.num_features();
  const DeviceBase::CpuWorkerThreads& worker_threads =
      *context->device()->tensorflow_cpu_worker_threads();

  Shard(worker_threads.num_threads, worker_threads.workers,
        examples.num_examples(), kCostPerUnit, train_step);
  // Propagate the first worker-thread error, if any, to the op status.
  mutex_lock l(train_step_status.mu);
  OP_REQUIRES_OK(context, train_step_status.value);
}
| |
| } // namespace |
| |
| class SdcaOptimizer : public OpKernel { |
| public: |
| explicit SdcaOptimizer(OpKernelConstruction* const context) |
| : OpKernel(context), options_(context) {} |
| |
| void Compute(OpKernelContext* context) override { |
| DoCompute(options_, context); |
| } |
| |
| private: |
| // TODO(sibyl-Aix6ihai): We could use the type-constraint on loss_type, and |
| // template the entire class to avoid the virtual table lookup penalty in |
| // the inner loop. |
| ComputeOptions options_; |
| }; |
| REGISTER_KERNEL_BUILDER(Name("SdcaOptimizer").Device(DEVICE_CPU), |
| SdcaOptimizer); |
| REGISTER_KERNEL_BUILDER(Name("SdcaOptimizerV2").Device(DEVICE_CPU), |
| SdcaOptimizer); |
| |
| class SdcaShrinkL1 : public OpKernel { |
| public: |
| explicit SdcaShrinkL1(OpKernelConstruction* const context) |
| : OpKernel(context) { |
| OP_REQUIRES_OK(context, regularizations_.Initialize(context)); |
| } |
| |
| void Compute(OpKernelContext* context) override { |
| OpMutableInputList weights_inputs; |
| OP_REQUIRES_OK(context, |
| context->mutable_input_list("weights", &weights_inputs)); |
| |
| auto do_work = [&](const int64 begin, const int64 end) { |
| for (int i = begin; i < end; ++i) { |
| auto prox_w = weights_inputs.at(i, /*lock_held=*/true).flat<float>(); |
| prox_w.device(context->eigen_cpu_device()) = |
| regularizations_.EigenShrinkVector(prox_w); |
| } |
| }; |
| |
| if (weights_inputs.size() > 0) { |
| int64 num_weights = 0; |
| for (int i = 0; i < weights_inputs.size(); ++i) { |
| num_weights += weights_inputs.at(i, /*lock_held=*/true).NumElements(); |
| } |
| // TODO(sibyl-Aix6ihai): Tune this value. |
| const int64 kCostPerUnit = (num_weights * 50) / weights_inputs.size(); |
| const DeviceBase::CpuWorkerThreads& worker_threads = |
| *context->device()->tensorflow_cpu_worker_threads(); |
| Shard(worker_threads.num_threads, worker_threads.workers, |
| weights_inputs.size(), kCostPerUnit, do_work); |
| } |
| } |
| |
| private: |
| Regularizations regularizations_; |
| }; |
| REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1); |
| |
| // Computes platform independent, compact and unique (with very high |
| // probability) representation of an example id. It shouldn't be put in |
| // persistent storage, as its implementation may change in the future. |
| // |
| // The current probability of at least one collision for 1B example_ids is |
| // approximately 10^-21 (ie 2^60 / 2^129). |
| class SdcaFprint : public OpKernel { |
| public: |
| explicit SdcaFprint(OpKernelConstruction* const context) |
| : OpKernel(context) {} |
| |
| void Compute(OpKernelContext* context) override { |
| const Tensor& input = context->input(0); |
| OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()), |
| errors::InvalidArgument("Input must be a vector, got shape ", |
| input.shape().DebugString())); |
| Tensor* out; |
| const int64 num_elements = input.NumElements(); |
| OP_REQUIRES_OK(context, context->allocate_output( |
| 0, TensorShape({num_elements, 2}), &out)); |
| |
| const auto in_values = input.flat<tstring>(); |
| auto out_values = out->matrix<int64>(); |
| |
| for (int64 i = 0; i < num_elements; ++i) { |
| const Fprint128 fprint = Fingerprint128(in_values(i)); |
| // Never return 0 or 1 as the first value of the hash to allow these to |
| // safely be used as sentinel values (e.g. dense hash table empty key). |
| out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2) |
| ? fprint.low64 |
| : fprint.low64 + ~static_cast<uint64>(1); |
| out_values(i, 1) = fprint.high64; |
| } |
| } |
| }; |
| REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint); |
| |
| } // namespace tensorflow |