| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/kernels/range_sampler.h" |
| |
| #include <cmath> |
| #include <unordered_set> |
| #include <vector> |
| |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/gtl/map_util.h" |
| #include "tensorflow/core/lib/io/inputbuffer.h" |
| #include "tensorflow/core/lib/strings/numbers.h" |
| #include "tensorflow/core/lib/strings/str_util.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/types.h" |
| |
| namespace tensorflow { |
| |
| using gtl::ArraySlice; |
| using gtl::MutableArraySlice; |
| |
| RangeSampler::~RangeSampler() {} |
| |
| void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique, |
| gtl::MutableArraySlice<int64> batch) const { |
| SampleBatchGetExpectedCount( |
| rnd, unique, batch, gtl::MutableArraySlice<float>(), |
| gtl::ArraySlice<int64>(), gtl::MutableArraySlice<float>()); |
| } |
| |
| void RangeSampler::SampleBatchGetExpectedCount( |
| random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice<int64> batch, |
| gtl::MutableArraySlice<float> batch_expected_count, |
| gtl::ArraySlice<int64> extras, |
| gtl::MutableArraySlice<float> extras_expected_count) const { |
| SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count, |
| extras, extras_expected_count, |
| gtl::ArraySlice<int64>()); |
| } |
| |
namespace {

// Approximate expected number of occurrences of a value with per-draw
// probability `p` in a batch produced by SampleBatch().
//
// `num_tries` is the number of draws that were actually made to collect
// `batch_size` values. When they are equal (always the case with
// unique=false), the estimate is simply p * batch_size.
//
// Otherwise (unique=true with rejected duplicates), we pretend -- not quite
// truthfully -- that collecting the batch always takes exactly `num_tries`
// draws, so the chance the value shows up at least once is
// 1 - (1 - p)^num_tries.
float ExpectedCountHelper(float p, int batch_size, int num_tries) {
  if (num_tries != batch_size) {
    // 1 - (1 - p)^num_tries evaluated through log1p/expm1 so that a small p
    // does not cancel out catastrophically.
    return -std::expm1(num_tries * std::log1p(-p));
  }
  return p * batch_size;
}

}  // namespace
| |
| void RangeSampler::SampleBatchGetExpectedCountAvoid( |
| random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch, |
| MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras, |
| MutableArraySlice<float> extras_expected_count, |
| ArraySlice<int64> avoided_values) const { |
| const int batch_size = batch.size(); |
| int num_tries; |
| |
| if (unique) { |
| CHECK_LE(batch_size + avoided_values.size(), range_); |
| std::unordered_set<int64> used(batch_size); |
| used.insert(avoided_values.begin(), avoided_values.end()); |
| int num_picked = 0; |
| num_tries = 0; |
| while (num_picked < batch_size) { |
| num_tries++; |
| CHECK_LT(num_tries, kint32max); |
| int64 value = Sample(rnd); |
| if (gtl::InsertIfNotPresent(&used, value)) { |
| batch[num_picked++] = value; |
| } |
| } |
| } else { |
| CHECK_EQ(avoided_values.size(), size_t{0}) |
| << "avoided_values only supported with unique=true"; |
| for (int i = 0; i < batch_size; i++) { |
| batch[i] = Sample(rnd); |
| } |
| num_tries = batch_size; |
| } |
| // Compute the expected counts of the batch and the extra values |
| if (!batch_expected_count.empty()) { |
| CHECK_EQ(batch_size, batch_expected_count.size()); |
| for (int i = 0; i < batch_size; i++) { |
| batch_expected_count[i] = |
| ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries); |
| } |
| } |
| CHECK_EQ(extras.size(), extras_expected_count.size()); |
| for (size_t i = 0; i < extras.size(); i++) { |
| extras_expected_count[i] = |
| ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries); |
| } |
| } |
| |
| AllSampler::AllSampler(int64 range) : RangeSampler(range) {} |
| |
| void AllSampler::SampleBatchGetExpectedCountAvoid( |
| random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch, |
| MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras, |
| MutableArraySlice<float> extras_expected_count, |
| ArraySlice<int64> avoided_values) const { |
| const int batch_size = batch.size(); |
| CHECK_EQ(range_, batch_size); |
| for (int i = 0; i < batch_size; i++) { |
| batch[i] = i; |
| } |
| if (!batch_expected_count.empty()) { |
| CHECK_EQ(batch_size, batch_expected_count.size()); |
| for (int i = 0; i < batch_size; i++) { |
| batch_expected_count[i] = 1; |
| } |
| } |
| CHECK_EQ(size_t{0}, avoided_values.size()); |
| CHECK_EQ(extras.size(), extras_expected_count.size()); |
| for (size_t i = 0; i < extras.size(); i++) { |
| extras_expected_count[i] = 1; |
| } |
| } |
| |
| UniformSampler::UniformSampler(int64 range) |
| : RangeSampler(range), inv_range_(1.0 / range) {} |
| |
| int64 UniformSampler::Sample(random::SimplePhilox* rnd) const { |
| return rnd->Uniform64(range_); |
| } |
| |
| float UniformSampler::Probability(int64 value) const { return inv_range_; } |
| |
| LogUniformSampler::LogUniformSampler(int64 range) |
| : RangeSampler(range), log_range_(log1p(range)) {} |
| |
| int64 LogUniformSampler::Sample(random::SimplePhilox* rnd) const { |
| const int64 value = |
| static_cast<int64>(exp(rnd->RandDouble() * log_range_)) - 1; |
| DCHECK_GE(value, 0); |
| // Mathematically, value should be <= range_, but might not be due to some |
| // floating point roundoff, so we mod by range_. In practice this case |
| // happens never regardless of the value of range_, including and up to |
| // DBL_MAX. But we include it as a guarantee of the function's output. |
| return value % range_; |
| } |
| |
| float LogUniformSampler::Probability(int64 value) const { |
| // value is returned iff the call to UniformDouble(log_range_) in the |
| // Sample() function returns a value between log(value + 1) |
| // and log(value + 2). The probability of this is: |
| // (log(value + 2) - log(value + 1)) / log_range |
| // To avoid two calls to log(), we compute this as follows: |
| return (log((value + 2.0) / (value + 1.0))) / log_range_; |
| } |
| |
| ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64 range) |
| : RangeSampler(range), picker_(range) { |
| CHECK_LT(range, kint32max); |
| } |
| |
| int64 ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const { |
| return picker_.Pick(rnd); |
| } |
| |
| float ThreadUnsafeUnigramSampler::Probability(int64 value) const { |
| return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight(); |
| } |
| |
| void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64> values) { |
| int num_updates = std::min(static_cast<int>(values.size()), |
| kint32max - picker_.total_weight()); |
| for (int i = 0; i < num_updates; i++) { |
| const int64 value = values[i]; |
| picker_.set_weight(value, picker_.get_weight(value) + 1); |
| } |
| } |
| |
| // Thread-safe unigram sampler |
| UnigramSampler::UnigramSampler(int64 range) |
| : RangeSampler(range), unsafe_sampler_(range) { |
| CHECK_LT(range, kint32max); |
| } |
| |
| int64 UnigramSampler::Sample(random::SimplePhilox* rnd) const { |
| mutex_lock lock(mu_); // could use reader lock |
| return unsafe_sampler_.Sample(rnd); |
| } |
| |
| float UnigramSampler::Probability(int64 value) const { |
| mutex_lock lock(mu_); // could use reader lock |
| return unsafe_sampler_.Probability(value); |
| } |
| |
| // Overriding at a high level results in far fewer lock acquisitions. |
| void UnigramSampler::SampleBatchGetExpectedCountAvoid( |
| random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch, |
| MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras, |
| MutableArraySlice<float> extras_expected_count, |
| ArraySlice<int64> avoided_values) const { |
| mutex_lock lock(mu_); // could use reader lock |
| unsafe_sampler_.SampleBatchGetExpectedCountAvoid( |
| rnd, unique, batch, batch_expected_count, extras, extras_expected_count, |
| avoided_values); |
| } |
| |
| void UnigramSampler::Update(ArraySlice<int64> values) { |
| mutex_lock lock(mu_); |
| unsafe_sampler_.Update(values); |
| } |
| |
| FixedUnigramSampler::FixedUnigramSampler(Env* env, int64 range, |
| const string& vocab_file, |
| float distortion, |
| int32 num_reserved_ids, |
| int32 num_shards, int32 shard) |
| : RangeSampler(range), |
| total_weight_(0.0), |
| num_shards_(num_shards), |
| shard_(shard) { |
| FillReservedIds(num_reserved_ids); |
| // TODO(vanhoucke): make this non-crashing. |
| TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion)); |
| CHECK_EQ(range, weights_.size()); |
| dist_sampler_.reset(new random::DistributionSampler(weights_)); |
| } |
| |
| FixedUnigramSampler::FixedUnigramSampler(int64 range, |
| const std::vector<float>& unigrams, |
| float distortion, |
| int32 num_reserved_ids, |
| int32 num_shards, int32 shard) |
| : RangeSampler(range), |
| total_weight_(0.0), |
| num_shards_(num_shards), |
| shard_(shard) { |
| FillReservedIds(num_reserved_ids); |
| LoadFromUnigrams(unigrams, distortion); |
| // TODO(vanhoucke): make this non-crashing. |
| CHECK_EQ(range, weights_.size()); |
| dist_sampler_.reset(new random::DistributionSampler(weights_)); |
| } |
| |
| float FixedUnigramSampler::Probability(int64 value) const { |
| if (value < 0 || static_cast<size_t>(value) >= weights_.size()) { |
| return 0.0; |
| } |
| return weights_.at(value) / total_weight_; |
| } |
| |
| int64 FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const { |
| return dist_sampler_->Sample(rnd); |
| } |
| |
| void FixedUnigramSampler::FillReservedIds(int32 num_reserved_ids) { |
| for (int32 word_id = 0; word_id < num_reserved_ids; ++word_id) { |
| if (word_id % num_shards_ == shard_) weights_.push_back(0.0); |
| } |
| } |
| |
| Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file, |
| float distortion) { |
| std::unique_ptr<RandomAccessFile> file; |
| TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file)); |
| |
| io::InputBuffer in(file.get(), 262144 /*bytes*/); |
| string line; |
| int32 word_id = weights_.size(); |
| while (in.ReadLine(&line).ok()) { |
| // The vocabulary file should be in csv like format, with the last |
| // field the weight associated with the word. |
| std::vector<string> cols = str_util::Split(line, ','); |
| if (cols.empty()) continue; |
| // Skip entries that do not belong to this shard. |
| if (word_id % num_shards_ == shard_) { |
| float w = 0.0; |
| if (!strings::safe_strtof(cols.at(cols.size() - 1), &w)) { |
| return errors::InvalidArgument("Wrong vocabulary format at line: ", |
| line); |
| } |
| w = std::pow(w, distortion); |
| total_weight_ += w; |
| weights_.push_back(w); |
| } |
| ++word_id; |
| } |
| return Status::OK(); |
| } |
| |
| void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams, |
| float distortion) { |
| int32 word_id = weights_.size(); |
| for (float w : unigrams) { |
| // Skip entries that do not belong to this shard. |
| if (word_id % num_shards_ == shard_) { |
| w = std::pow(w, distortion); |
| total_weight_ += w; |
| weights_.push_back(w); |
| } |
| ++word_id; |
| } |
| } |
| |
| } // namespace tensorflow |