// blob: e0e7deb98e579c090c8fae320a3ba8a3ce8dbe5f [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace {
using tensorflow::int16;
using tensorflow::int32;
using tensorflow::int64;
using tensorflow::int8;
using tensorflow::uint16;
using tensorflow::uint32;
using tensorflow::uint64;
using tensorflow::uint8;
// Sorts the `num_elements` (key, index) pairs starting at `row_to_sort` in
// ascending order. Pairs are ordered by key first; pairs with equivalent keys
// are ordered by their index (the pair's second member).
template <typename KeyType>
void KeyValueSort(std::pair<KeyType, int64>* row_to_sort, int64 num_elements) {
  using Entry = std::pair<KeyType, int64>;
  Entry* const row_end = row_to_sort + num_elements;
  std::sort(row_to_sort, row_end, [](const Entry& a, const Entry& b) -> bool {
    // Lexicographic comparison: key first, then index.
    if (a.first < b.first) {
      return true;
    }
    if (b.first < a.first) {
      return false;
    }
    return a.second < b.second;
  });
}
// Total-order "less than" comparator for floating point sort keys, using the
// element index as a tie-breaker so that the resulting sort is stable.
//
// The ordering is:
//   -NaN < -Inf < ... < -0.0 < 0.0 < ... < Inf < NaN
//
// Keys that occupy the same position in this ordering (equal values, or two
// NaNs with the same sign) are ordered by their index. Without the explicit
// NaN handling below, two same-signed NaNs would never reach the index
// comparison (NaN != NaN is always true and NaN < NaN is always false), which
// would make them compare as equivalent and break stability.
//
// Args:
//   lhs, rhs: the keys to compare.
//   lhs_index, rhs_index: the original positions of the keys in the row.
// Returns:
//   true iff (lhs, lhs_index) must be ordered before (rhs, rhs_index).
template <typename KeyType>
bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) {
  const bool lhs_is_negative = std::signbit(lhs);
  const bool rhs_is_negative = std::signbit(rhs);
  // If the signs are different, the negative number comes first. This also
  // orders -0.0 before 0.0 and -NaN before everything non-negative.
  if (lhs_is_negative != rhs_is_negative) {
    return lhs_is_negative;
  }
  const bool lhs_nan = std::isnan(lhs);
  const bool rhs_nan = std::isnan(rhs);
  // Exactly one number is NaN (both have the same sign here): a negative NaN
  // is smaller than every other negative number, a positive NaN is larger
  // than every other non-negative number.
  if (lhs_nan != rhs_nan) {
    if (lhs_nan) {
      return lhs_is_negative;
    }
    return !rhs_is_negative;
  }
  // Both are ordinary numbers with different values. If both are NaN (with the
  // same sign) we deliberately skip this and treat them as equal keys.
  if (!lhs_nan && lhs != rhs) {
    return lhs < rhs;
  }
  // Equal keys: fall back to the index to keep the sort stable.
  return lhs_index < rhs_index;
}
// Specialization for double keys. Instead of the default pair comparison, the
// (key, index) pairs are ordered with LessThan, which is handed each pair's
// key and index.
template <>
void KeyValueSort(std::pair<double, int64>* row_to_sort, int64 num_elements) {
  using Entry = std::pair<double, int64>;
  const auto total_order_less = [](const Entry& a, const Entry& b) -> bool {
    return LessThan(a.first, a.second, b.first, b.second);
  };
  Entry* const row_end = row_to_sort + num_elements;
  std::sort(row_to_sort, row_end, total_order_less);
}
// Specialization for float keys. Instead of the default pair comparison, the
// (key, index) pairs are ordered with LessThan, which is handed each pair's
// key and index.
template <>
void KeyValueSort(std::pair<float, int64>* row_to_sort, int64 num_elements) {
  using Entry = std::pair<float, int64>;
  const auto total_order_less = [](const Entry& a, const Entry& b) -> bool {
    return LessThan(a.first, a.second, b.first, b.second);
  };
  Entry* const row_end = row_to_sort + num_elements;
  std::sort(row_to_sort, row_end, total_order_less);
}
// Specialization for Eigen::half keys. Each key is converted to float with
// Eigen::half_impl::half_to_float and the (key, index) pairs are then ordered
// with LessThan on the converted keys and the original indices.
template <>
void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort,
                  int64 num_elements) {
  using Entry = std::pair<Eigen::half, int64>;
  const auto total_order_less = [](const Entry& a, const Entry& b) -> bool {
    const float a_key = Eigen::half_impl::half_to_float(a.first);
    const float b_key = Eigen::half_impl::half_to_float(b.first);
    return LessThan(a_key, a.second, b_key, b.second);
  };
  Entry* const row_end = row_to_sort + num_elements;
  std::sort(row_to_sort, row_end, total_order_less);
}
template <typename KeyType>
void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
// High-level idea of the iteration/sorting logic:
// Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the
// dimension to sort, c is the product of the more minor dimensions (set to 1
// if b is the most minor dimension), and a is the product of the more major
// dimensions (set to 1 if b is the most major dimension). There are a * c
// many rows that we need to sort. We iterate through these, calculate a
// 'base_offset' value which points to the first element in that row, and add
// i * c for accessing the 'i'-th element in that row.
int64 sort_dimension_elements = b;
int64 num_iteration_elements = a * c;
int64 sort_dimension_offset = c;
std::unique_ptr<std::pair<KeyType, int64>[]> row_to_sort(
new std::pair<KeyType, int64>[sort_dimension_elements]);
std::unique_ptr<std::string[]> reordered_values(
new std::string[sort_dimension_elements]);
for (int64 index = 0; index < num_iteration_elements; ++index) {
// 'index' can be split into two values which index into the 'c' dimension
// and the 'a' dimension, respectively. 'index' % 'c' is the index into the
// 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When
// calculating the base offset, we need to multiply the index into the 'a'
// dimension with 'b' * 'c'.
// 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'.
int64 base_offset =
index % sort_dimension_offset +
(index - index % sort_dimension_offset) * sort_dimension_elements;
// TODO(b/26783907): We could define a custom iterator class that references
// both arrays. Then we could avoid the intermediate copy. However this
// would become more complicated, and it is not clear if the benefit is high
// enough.
for (int64 i = 0; i < sort_dimension_elements; ++i) {
row_to_sort[i] =
std::make_pair(keys[base_offset + i * sort_dimension_offset], i);
}
KeyValueSort(row_to_sort.get(), sort_dimension_elements);
for (int64 i = 0; i < sort_dimension_elements; ++i) {
keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first;
}
if (values == nullptr) {
continue;
}
// Reorder the values according to the order defined by the keys.
for (int64 i = 0; i < sort_dimension_elements; ++i) {
int64 memory_index =
(base_offset + row_to_sort[i].second * sort_dimension_offset) *
values_primitive_type_size_in_bytes;
reordered_values[i] = std::string(values + memory_index,
values_primitive_type_size_in_bytes);
}
for (int64 i = 0; i < sort_dimension_elements; ++i) {
int64 memory_index = (base_offset + i * sort_dimension_offset) *
values_primitive_type_size_in_bytes;
memcpy(values + memory_index, reordered_values[i].c_str(),
values_primitive_type_size_in_bytes);
}
}
}
} // namespace
// Key/value sort for bool (PRED) keys. Non-template wrapper that forwards all
// arguments unchanged to KeyValueSortImpl<bool>; see that function for the
// meaning of a, b, c and the optional `values` buffer.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED(
    bool* keys, int64 a, int64 b, int64 c, char* values,
    int32 values_primitive_type_size_in_bytes) {
  KeyValueSortImpl<bool>(keys, a, b, c, values,
                         values_primitive_type_size_in_bytes);
}
// Key/value sort for int8 keys. Non-template wrapper that forwards all
// arguments unchanged to KeyValueSortImpl<int8>; see that function for the
// meaning of a, b, c and the optional `values` buffer.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8(
    int8* keys, int64 a, int64 b, int64 c, char* values,
    int32 values_primitive_type_size_in_bytes) {
  KeyValueSortImpl<int8>(keys, a, b, c, values,
                         values_primitive_type_size_in_bytes);
}
// Key/value sort for uint8 keys. Non-template wrapper that forwards all
// arguments unchanged to KeyValueSortImpl<uint8>; see that function for the
// meaning of a, b, c and the optional `values` buffer.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8(
    uint8* keys, int64 a, int64 b, int64 c, char* values,
    int32 values_primitive_type_size_in_bytes) {
  KeyValueSortImpl<uint8>(keys, a, b, c, values,
                          values_primitive_type_size_in_bytes);
}
// Key/value sort for int16 keys. Non-template wrapper that forwards all
// arguments unchanged to KeyValueSortImpl<int16>; see that function for the
// meaning of a, b, c and the optional `values` buffer.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16(
    int16* keys, int64 a, int64 b, int64 c, char* values,
    int32 values_primitive_type_size_in_bytes) {
  KeyValueSortImpl<int16>(keys, a, b, c, values,
                          values_primitive_type_size_in_bytes);
}
// Key/value sort for uint16 keys. Non-template wrapper that forwards all
// arguments unchanged to KeyValueSortImpl<uint16>; see that function for the
// meaning of a, b, c and the optional `values` buffer.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16(
    uint16* keys, int64 a, int64 b, int64 c, char* values,
    int32 values_primitive_type_size_in_bytes) {
  KeyValueSortImpl<uint16>(keys, a, b, c, values,
                           values_primitive_type_size_in_bytes);
}
// Key/value sort for Eigen::half (F16) keys. Non-template wrapper that
// forwards all arguments unchanged to KeyValueSortImpl<Eigen::half>; see that
// function for the meaning of a, b, c and the optional `values` buffer.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16(
    Eigen::half* keys, int64 a, int64 b, int64 c, char* values,
    int32 values_primitive_type_size_in_bytes) {
  KeyValueSortImpl<Eigen::half>(keys, a, b, c, values,
                                values_primitive_type_size_in_bytes);
}
// Key/value sort for int32 keys. Non-template wrapper that forwards all
// arguments unchanged to KeyValueSortImpl<int32>; see that function for the
// meaning of a, b, c and the optional `values` buffer.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32(
    int32* keys, int64 a, int64 b, int64 c, char* values,
    int32 values_primitive_type_size_in_bytes) {
  KeyValueSortImpl<int32>(keys, a, b, c, values,
                          values_primitive_type_size_in_bytes);
}
// Key/value sort for uint32 keys. Non-template wrapper that forwards all
// arguments unchanged to KeyValueSortImpl<uint32>; see that function for the
// meaning of a, b, c and the optional `values` buffer.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32(
    uint32* keys, int64 a, int64 b, int64 c, char* values,
    int32 values_primitive_type_size_in_bytes) {
  KeyValueSortImpl<uint32>(keys, a, b, c, values,
                           values_primitive_type_size_in_bytes);
}
// Key/value sort for float (F32) keys. Non-template wrapper that forwards all
// arguments unchanged to KeyValueSortImpl<float>; see that function for the
// meaning of a, b, c and the optional `values` buffer.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32(
    float* keys, int64 a, int64 b, int64 c, char* values,
    int32 values_primitive_type_size_in_bytes) {
  KeyValueSortImpl<float>(keys, a, b, c, values,
                          values_primitive_type_size_in_bytes);
}
// Key/value sort for int64 keys. Non-template wrapper that forwards all
// arguments unchanged to KeyValueSortImpl<int64>; see that function for the
// meaning of a, b, c and the optional `values` buffer.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64(
    int64* keys, int64 a, int64 b, int64 c, char* values,
    int32 values_primitive_type_size_in_bytes) {
  KeyValueSortImpl<int64>(keys, a, b, c, values,
                          values_primitive_type_size_in_bytes);
}
// Key/value sort for uint64 keys. Non-template wrapper that forwards all
// arguments unchanged to KeyValueSortImpl<uint64>; see that function for the
// meaning of a, b, c and the optional `values` buffer.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64(
    uint64* keys, int64 a, int64 b, int64 c, char* values,
    int32 values_primitive_type_size_in_bytes) {
  KeyValueSortImpl<uint64>(keys, a, b, c, values,
                           values_primitive_type_size_in_bytes);
}
// Key/value sort for double (F64) keys. Non-template wrapper that forwards all
// arguments unchanged to KeyValueSortImpl<double>; see that function for the
// meaning of a, b, c and the optional `values` buffer.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64(
    double* keys, int64 a, int64 b, int64 c, char* values,
    int32 values_primitive_type_size_in_bytes) {
  KeyValueSortImpl<double>(keys, a, b, c, values,
                           values_primitive_type_size_in_bytes);
}