blob: 330a0989a584c2259281eb7b53644a63a15a757e [file] [log] [blame]
#include "device.h"
#include <stdexcept>
#include <string>
namespace torch {
Device::Device(DeviceType type, int64_t index, bool is_default)
: type(type), index(index), is_default(is_default) {
if (!is_default) {
switch (type) {
case DeviceType::CPU:
if (index != 0) {
throw std::runtime_error("cpu device index must be 0, got " + std::to_string(index));
}
break;
case DeviceType::CUDA:
if (index < 0) {
throw std::runtime_error("device index must be positive, got " + std::to_string(index));
}
break;
default:
throw std::runtime_error("unexpected DeviceType");
}
}
}
bool Device::operator==(const Device& rhs) {
return this->type == rhs.type && this->index == rhs.index && this->is_default == rhs.is_default;
}
}