// blob: 375d30c4cf3e7d98f7c24420ac751b527f0e20b3 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/device_util.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow {
namespace jit {
using xla::StatusOr;
// Adds `device_id` to the set by setting the corresponding bit in `storage_`.
// Each element of `storage_` holds kWordSize consecutive ids; the storage is
// extended with zeroed words when the id falls past its current end.
void DeviceSet::Insert(DeviceId device_id) {
  const int word = device_id.id() / kWordSize;
  const int bit = device_id.id() % kWordSize;

  // Grow first so that the indexing below is always in range.
  if (storage_.size() <= word) {
    storage_.resize(word + 1, 0);
  }

  storage_[word] |= (1ull << bit);
}
// Makes this set the union of itself and `other` by OR-ing `other`'s words
// into this set's words. This set's storage is first widened (with zeroed
// words) if `other` has more words than it does.
void DeviceSet::UnionWith(const DeviceSet& other) {
  const auto other_word_count = other.storage_.size();
  if (storage_.size() < other_word_count) {
    storage_.resize(other_word_count, 0);
  }

  // Only the words present in `other` can contribute new bits.
  for (int word = 0; word < other_word_count; ++word) {
    storage_[word] |= other.storage_[word];
  }
}
bool DeviceSet::IsEmpty() const {
return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
}
// Returns the DeviceId registered for the device called `name`.
//
// If `name` has been seen before, its existing id is returned. Otherwise a new
// id (the next index into the per-id tables) is allocated and the device's
// name, DeviceType, is-CPU / is-GPU flags and XLA compilation device
// registration (nullptr if XlaOpRegistry has none for the device type) are
// recorded for it.
//
// Returns an error if `name` is empty or if DeviceNameToDeviceType rejects it.
// All fallible work is done on locals before any member is touched, so a
// failed call leaves the cache unchanged and the per-id tables (`names_`,
// `id_to_device_type_`, `is_cpu_`, `is_gpu_`, `id_to_compilation_device_`)
// always have the same length.
xla::StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
  TF_RET_CHECK(!name.empty());

  auto it = name_to_id_.find(name);
  if (it != name_to_id_.end()) {
    return it->second;
  }

  // Parse into locals first. Appending to `names_` / `id_to_device_type_`
  // before this check would leave those tables one entry longer than the
  // others whenever the name is malformed, desynchronizing every id handed
  // out afterwards.
  const string device_name(name);
  DeviceType device_type("");
  TF_RETURN_IF_ERROR(DeviceNameToDeviceType(device_name, &device_type));

  const XlaOpRegistry::DeviceRegistration* compilation_device;
  if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
                                           &compilation_device)) {
    compilation_device = nullptr;
  }

  // Nothing below can return early, so the tables are extended together.
  int new_id = names_.size();
  names_.push_back(device_name);
  id_to_device_type_.push_back(absl::make_unique<DeviceType>(device_type));
  is_cpu_.push_back(device_type.type_string() == DEVICE_CPU);
  is_gpu_.push_back(device_type.type_string() == DEVICE_GPU);
  id_to_compilation_device_.push_back(compilation_device);
  name_to_id_.emplace(device_name, DeviceId(new_id));

  return DeviceId(new_id);
}
// Renders `device_set` as a bracketed, comma-separated list of the names of
// the devices ForEach visits, e.g. "[nameA,nameB]".
string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
  std::vector<string> device_names;
  device_set.ForEach([&](DeviceId id) {
    device_names.push_back(string(GetNameFor(id)));
    // NOTE(review): `true` presumably tells ForEach to keep iterating --
    // confirm against DeviceSet::ForEach.
    return true;
  });

  const string joined = absl::StrJoin(device_names, ",");
  return absl::StrCat("[", joined, "]");
}
} // namespace jit
// Parses the full device name `device` and stores a DeviceType built from its
// parsed `type` field in `*device_type`.
//
// Returns an Internal error, leaving `*device_type` unmodified, if `device`
// cannot be parsed by DeviceNameUtils::ParseFullName.
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
  DeviceNameUtils::ParsedName parsed_name;
  const bool parsed_ok = DeviceNameUtils::ParseFullName(device, &parsed_name);
  if (parsed_ok) {
    *device_type = DeviceType(parsed_name.type);
    return Status::OK();
  }
  return errors::Internal("Malformed assigned device '", device, "'");
}
// Shared implementation of PickDeviceForXla and MaybePickDeviceForXla: picks a
// single device out of `devices`.
//
// The devices are bucketed into GPU, CPU and "unknown" (neither, according to
// `device_info_cache`). Picking fails when
//   * two CPU names or two GPU names cannot be reconciled with each other,
//   * there is more than one unknown device,
//   * an unknown device is mixed with a GPU device,
//   * an unknown device is mixed with a CPU device and
//     `allow_mixing_unknown_and_cpu` is false, or
//   * `devices` is empty.
// Otherwise the GPU device is preferred, then the unknown device, then the CPU
// device.
//
// On failure the result is an Internal error if `failure_to_pick_is_error` is
// true and an OK StatusOr holding absl::nullopt otherwise.
xla::StatusOr<absl::optional<jit::DeviceId>> PickDeviceForXlaImpl(
    const jit::DeviceInfoCache& device_info_cache,
    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu,
    bool failure_to_pick_is_error) {
// Leaves the function on a failed pick. `failing_status` is only evaluated
// when failures are reported as errors, so the (string-building) status
// expressions cost nothing in the nullopt mode.
#define FAILED_TO_PICK_DEVICE(failing_status) \
  do {                                        \
    if (failure_to_pick_is_error) {           \
      return failing_status;                  \
    } else {                                  \
      return {absl::nullopt};                 \
    }                                         \
  } while (false)

  // The device retained so far in each bucket, if any.
  absl::optional<jit::DeviceId> gpu_choice;
  absl::optional<jit::DeviceId> cpu_choice;
  absl::optional<jit::DeviceId> unknown_choice;

  // Set when a bucket was found to contain irreconcilable devices.
  bool gpu_conflict = false;
  bool cpu_conflict = false;
  bool unknown_conflict = false;

  // Returns true if `candidate` conflicts with the device held in `*chosen`.
  // If the two are compatible, `*chosen` is left holding the more specific of
  // the two and false is returned. `*chosen` must hold a value.
  //
  // TODO(sanjoy): Cache DeviceNameUtils::ParsedName inside device_info_cache.
  const auto conflicts_with =
      [&](const jit::DeviceId& candidate,
          absl::optional<jit::DeviceId>* chosen) -> bool {
    const absl::string_view candidate_name =
        device_info_cache.GetNameFor(candidate);
    const absl::string_view chosen_name =
        device_info_cache.GetNameFor(chosen->value());

    DeviceNameUtils::ParsedName candidate_parsed, chosen_parsed;
    if (!DeviceNameUtils::ParseFullName(candidate_name, &candidate_parsed)) {
      return true;
    }
    if (!DeviceNameUtils::ParseFullName(chosen_name, &chosen_parsed)) {
      return true;
    }
    if (!DeviceNameUtils::AreCompatibleDevNames(candidate_parsed,
                                                chosen_parsed)) {
      return true;
    }

    // Keep the current choice.
    if (DeviceNameUtils::IsSpecification(candidate_parsed, chosen_parsed)) {
      return false;
    }
    // Switch the choice over to the candidate.
    if (DeviceNameUtils::IsSpecification(chosen_parsed, candidate_parsed)) {
      *chosen = candidate;
      return false;
    }
    return true;
  };

  // Common bookkeeping for the CPU and GPU buckets: adopts `device` if the
  // bucket is still empty, otherwise reconciles it with the current choice and
  // records the outcome in `*conflict`. Returns whether iteration over the
  // device set should go on.
  const auto remember = [&](jit::DeviceId device,
                            absl::optional<jit::DeviceId>* choice,
                            bool* conflict) -> bool {
    if (!choice->has_value()) {
      *choice = device;
      return true;
    }
    *conflict = conflicts_with(device, choice);
    return !*conflict;
  };

  // Returning false from the callback abandons the scan at the first conflict.
  devices.ForEach([&](jit::DeviceId device) {
    if (device_info_cache.IsGpu(device)) {
      return remember(device, &gpu_choice, &gpu_conflict);
    }
    if (device_info_cache.IsCpu(device)) {
      return remember(device, &cpu_choice, &cpu_conflict);
    }
    // Unknown devices are never reconciled with each other: a second one is
    // always a conflict.
    if (unknown_choice) {
      unknown_conflict = true;
      return false;
    }
    unknown_choice = device;
    return true;
  });

  if (cpu_conflict) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Multiple CPU devices ", device_info_cache.DebugString(devices)));
  }
  if (gpu_conflict) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Multiple GPU devices ", device_info_cache.DebugString(devices)));
  }
  if (unknown_conflict) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Multiple unknown devices ", device_info_cache.DebugString(devices)));
  }

  if (unknown_choice && gpu_choice) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Found both unknown and GPU devices: ",
        device_info_cache.GetNameFor(*unknown_choice), ", ",
        device_info_cache.GetNameFor(*gpu_choice)));
  }
  if (!allow_mixing_unknown_and_cpu && unknown_choice && cpu_choice) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Found both unknown and CPU devices: ",
        device_info_cache.GetNameFor(*unknown_choice), ", ",
        device_info_cache.GetNameFor(*cpu_choice)));
  }

  // Preference order: GPU, then unknown, then CPU.
  if (gpu_choice) {
    return {*gpu_choice};
  }
  if (unknown_choice) {
    return {*unknown_choice};
  }
  if (cpu_choice) {
    return {*cpu_choice};
  }

  FAILED_TO_PICK_DEVICE(errors::Internal("Empty device set!"));
#undef FAILED_TO_PICK_DEVICE
}
// Picks the device out of `devices` via PickDeviceForXlaImpl, treating a
// failure to pick as an error: the result is either the chosen DeviceId or the
// error status produced by PickDeviceForXlaImpl.
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
    const jit::DeviceInfoCache& device_info_cache,
    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
  xla::StatusOr<absl::optional<jit::DeviceId>> picked = PickDeviceForXlaImpl(
      device_info_cache, devices, allow_mixing_unknown_and_cpu,
      /*failure_to_pick_is_error=*/true);
  if (!picked.ok()) {
    return picked.status();
  }
  // With failure_to_pick_is_error set, PickDeviceForXlaImpl reports failures
  // as errors rather than as nullopt, so an OK result holds a device.
  return *picked.ValueOrDie();
}
// Variant of PickDeviceForXla that does not treat a failure to pick as an
// error: it forwards to PickDeviceForXlaImpl with failure_to_pick_is_error
// disabled, so a failed pick is returned as absl::nullopt.
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
    const jit::DeviceInfoCache& device_info_cache,
    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
  constexpr bool kFailureToPickIsError = false;
  return PickDeviceForXlaImpl(device_info_cache, devices,
                              allow_mixing_unknown_and_cpu,
                              kFailureToPickIsError);
}
} // namespace tensorflow