| // Original TunableOp is from onnxruntime. |
| // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h |
| // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable |
| // Copyright (c) Microsoft Corporation. |
| // Licensed under the MIT license. |
| // |
| // Adapting TunableOp into PyTorch |
| // Copyright (c) Advanced Micro Devices, Inc. |
| // |
| #include <cuda_runtime.h> |
| |
| #include <ATen/cuda/CUDAContextLight.h> |
| #include <ATen/cuda/tunable/Tunable.h> |
| #include <c10/util/Exception.h> |
| #include <c10/util/StringUtil.h> |
| #include <torch/version.h> |
| |
| #ifndef _WIN32 |
| #include <cxxabi.h> |
| #endif |
| |
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iterator>
#include <limits>
| #include <memory> |
| #include <mutex> |
| #include <sstream> |
| #include <string> |
| #include <thread> |
| #include <type_traits> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include <utility> |
| #include <vector> |
| |
| // for validators |
| #ifdef USE_ROCM |
| #include <rocm-core/rocm_version.h> |
| #define ROCBLAS_BETA_FEATURES_API |
| #include <rocblas/rocblas.h> |
| #include <hipblaslt/hipblaslt.h> |
| #include <hipblaslt/hipblaslt-ext.hpp> |
| #endif |
| |
| namespace at::cuda::tunable { |
| |
| namespace { |
| |
| TuningContext tuning_context; |
| |
| } // anonymous namespace |
| |
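// Returns the process-wide TuningContext singleton.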
| TuningContext* getTuningContext() { |
| return &tuning_context; |
| } |
| |
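// A ResultEntry serializes as "<kernel_id>,<time>", matching the layout of
// result lines in the tuning results file.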
| std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry) { |
| return stream << entry.key_ << "," << entry.time_; |
| } |
| |
| // TuningResultsManager |
| |
| KernelMap TuningResultsManager::Lookup(const std::string& op_signature) { |
| std::scoped_lock l{lock_}; |
| auto it = results_.find(op_signature); |
| if (it == results_.cend()) { |
| return {}; |
| } |
| return it->second; // copied |
| } |
| |
| ResultEntry TuningResultsManager::Lookup(const std::string& op_signature, const std::string& params_signature) { |
| std::scoped_lock l{lock_}; |
| auto kernel_map_it = results_.find(op_signature); |
| if (kernel_map_it == results_.cend()) { |
| TUNABLE_LOG3("missing op_signature, returning null ResultEntry for ", op_signature, ",", params_signature); |
| return ResultEntry::Null(); |
| } |
| |
| const auto& km = kernel_map_it->second; |
| auto it = km.find(params_signature); |
| if (it == km.cend()) { |
| TUNABLE_LOG3("missing params_signature, returning null ResultEntry for ", op_signature, ",", params_signature); |
| return ResultEntry::Null(); |
| } |
| TUNABLE_LOG3("ResultEntry found for ", op_signature, ",", params_signature); |
| return it->second; |
| } |
| |
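// Expects the caller to hold lock_. The first result recorded for a given
// (op_signature, params_signature) pair wins; a later, different result is
// logged and ignored.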
| inline void TuningResultsManager::AddImpl(const std::string& op_signature, |
| const std::string& params_signature, |
| ResultEntry best, |
| KernelMap& kernel_map) { |
| auto it = kernel_map.find(params_signature); |
| if (it != kernel_map.end()) { |
| if (it->second != best) { |
      TUNABLE_LOG1(op_signature, "(", params_signature, ") already has a best kernel ",
          "id=", it->second, " selected; the different candidate kernel ", best,
          " will be ignored.");
| } |
| return; |
| } |
| |
| TUNABLE_LOG2(op_signature, "(", params_signature, ") -> ", best); |
| kernel_map.emplace(params_signature, best); |
| } |
| |
| void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) { |
| std::scoped_lock l{lock_}; |
| |
| auto it = results_.find(op_signature); |
| if (it == results_.end()) { |
| it = results_.insert({op_signature, {}}).first; |
| } |
| |
| AddImpl(op_signature, params_signature, best, it->second); |
| } |
| |
| void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) { |
| std::scoped_lock l{lock_}; |
| |
| auto it = results_.find(op_signature); |
| if (it == results_.end()) { |
| return; |
| } |
| |
| auto it2 = it->second.find(params_signature); |
| if (it2 == it->second.end()) { |
| return; |
| } |
| |
  TUNABLE_LOG2("deleting ", op_signature, "(", params_signature, ")");
| it->second.erase(it2); |
| } |
| |
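// Expects the caller to hold the lock. Merges kernel_map into results without
// overwriting: conflicting entries keep the already-stored result (see AddImpl).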
| inline void TuningResultsManager::DisjointMergeImpl( |
| const std::string& op_signature, |
| const KernelMap& kernel_map, |
| /*out*/ std::unordered_map<std::string, KernelMap>& results) { |
| auto it = results.find(op_signature); |
| if (it == results.end()) { |
| for (const auto& [param_sig, kernel_id] : kernel_map) { |
| TUNABLE_LOG2(op_signature, "(", param_sig, ") -> ", kernel_id); |
| } |
| results[op_signature] = kernel_map; |
| return; |
| } |
| |
| for (const auto& [params_signature, best] : kernel_map) { |
| AddImpl(op_signature, params_signature, best, it->second); |
| } |
| } |
| |
| void TuningResultsManager::Load(const std::unordered_map<std::string, KernelMap>& results_to_load) { |
| TUNABLE_LOG1("Loading results"); |
| std::scoped_lock l{lock_}; |
| for (const auto& [op_signature, kernel_map] : results_to_load) { |
| DisjointMergeImpl(op_signature, kernel_map, results_); |
| } |
| } |
| |
| ResultsMap TuningResultsManager::Dump() { |
| std::scoped_lock l{lock_}; |
| return results_; |
| } |
| |
| void TuningResultsManager::DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map) { |
| std::scoped_lock l{lock_}; |
| DisjointMergeImpl(op_signature, kernel_map, results_); |
| } |
| |
| size_t TuningResultsManager::GetSize() { |
| size_t size = 0; |
| std::scoped_lock l{lock_}; |
| for (const auto& [op_signature, kernel_map] : results_) { |
| size += kernel_map.size(); |
| } |
| return size; |
| } |
| |
| // TuningResultsValidator |
| |
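// Registers the build-environment validators that are written to and checked
// against a results file. PT_VERSION is always registered; ROCm builds also
// register ROCM_VERSION, GCN_ARCH_NAME, ROCBLAS_VERSION, and HIPBLASLT_VERSION.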
| TuningResultsValidator::TuningResultsValidator() { |
| RegisterValidator( |
| "PT_VERSION", |
| [this]() { return GetPyTorchVersion(); }, |
| [this](auto&& k) { return ValidatePyTorchVersion(std::forward<decltype(k)>(k)); }); |
| #ifdef USE_ROCM |
| // rocm |
| { |
| std::string rocm_version = ROCM_BUILD_INFO; |
| RegisterValidator( |
| "ROCM_VERSION", |
| [rocm_version]() { return rocm_version; }, |
| [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; }); |
| } |
| // gfx arch |
| { |
| std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName; |
| RegisterValidator( |
| "GCN_ARCH_NAME", |
| [gcn_arch_name]() { return gcn_arch_name; }, |
| [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; }); |
| } |
| // rocblas |
| { |
| #define STRINGIFY(s) #s |
| #define XSTRINGIFY(s) STRINGIFY(s) |
| std::string rocblas_version = c10::str( |
| XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".", |
| XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".", |
| XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-", |
| XSTRINGIFY(ROCBLAS_VERSION_TWEAK)); |
| #undef XSTRINGIFY |
| #undef STRINGIFY |
| RegisterValidator( |
| "ROCBLAS_VERSION", |
| [rocblas_version]() { return rocblas_version; }, |
| [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; }); |
| } |
| // hipblaslt |
| { |
    int version = 0;
| std::string revision(128, '\0'); |
| auto handle = at::cuda::getCurrentCUDABlasLtHandle(); |
| hipblasLtGetVersion(handle, &version); |
| hipblasLtGetGitRevision(handle, revision.data()); |
| std::string hipblaslt_version = |
| c10::str(version, "-", revision.c_str()); |
| RegisterValidator( |
| "HIPBLASLT_VERSION", |
| [hipblaslt_version]() { return hipblaslt_version; }, |
| [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; }); |
| } |
| #endif |
| } |
| |
| std::unordered_map<std::string, std::string> TuningResultsValidator::GetAllValidators() const { |
| std::unordered_map<std::string, std::string> ret; |
| for (const auto& [key, get_validate_func_pair] : validators_) { |
| const GetFunc& getter = get_validate_func_pair.first; |
| ret[key] = getter(); |
| } |
| return ret; |
| } |
| |
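// Fails if any mandatory key is missing from either the registered validators
// or the validators provided by the results being checked.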
| static bool CheckMandatoryKeys( |
| const TuningResultsValidator::GetValidateFuncs& gv_funcs, |
| const std::unordered_map<std::string, std::string>& to_check) { |
| bool passed = true; |
| for (const auto& k : TuningResultsValidator::mandatory_keys) { |
| if (gv_funcs.find(k) == gv_funcs.end()) { |
| passed = false; |
| TUNABLE_LOG1("key=\"", k, "\" is not registered for Get and Validate. "); |
| } |
| |
| if (to_check.find(k) == to_check.end()) { |
| passed = false; |
| TUNABLE_LOG1("key=\"", k, "\" is not provided for validation. "); |
| } |
| } |
| return passed; |
| } |
| |
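// Requires the registered validator keys and the provided validator keys to be
// identical sets; warns once per key on either side of a mismatch.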
| static bool CheckKeysMatching( |
| const TuningResultsValidator::GetValidateFuncs& gv_funcs, |
| const std::unordered_map<std::string, std::string>& to_check) { |
| auto get_keys = [](const auto& it) -> std::string { return it.first; }; |
| std::vector<std::string> required_keys; |
| std::vector<std::string> provided_keys; |
| std::transform(gv_funcs.cbegin(), gv_funcs.cend(), std::back_inserter(required_keys), get_keys); |
| std::transform(to_check.cbegin(), to_check.cend(), std::back_inserter(provided_keys), get_keys); |
| std::sort(required_keys.begin(), required_keys.end()); |
| std::sort(provided_keys.begin(), provided_keys.end()); |
| |
| std::unordered_set<std::string> intersection; |
| std::set_intersection(required_keys.cbegin(), required_keys.cend(), |
| provided_keys.cbegin(), provided_keys.cend(), |
| std::inserter(intersection, intersection.end())); |
| bool matched = true; |
| if (intersection.size() != required_keys.size()) { |
| matched = false; |
| for (const auto& k : required_keys) { |
| if (intersection.find(k) == intersection.end()) { |
| TORCH_WARN("Unmatched validator: \"", k, "\" is required, but the tuning results does not provide it. "); |
| } |
| } |
| } |
| if (intersection.size() != provided_keys.size()) { |
| matched = false; |
| for (const auto& k : provided_keys) { |
| if (intersection.find(k) == intersection.end()) { |
| TORCH_WARN("Unmatched validator: \"", k, "\" is provided, but pytorch is unable to consume it. "); |
| } |
| } |
| } |
| return matched; |
| } |
| |
| TuningStatus TuningResultsValidator::ValidateAll( |
| const std::unordered_map<std::string, std::string>& to_validate) const { |
| if (!CheckMandatoryKeys(validators_, to_validate)) { |
| return FAIL; |
| } |
| if (!CheckKeysMatching(validators_, to_validate)) { |
| return FAIL; |
| } |
| |
| for (const auto& [key, value] : to_validate) { |
| const auto& it = validators_.find(key); |
| if (it == validators_.cend()) { |
| TORCH_WARN("Failed to lookup validator using key ", key); |
      for (const auto& kv : validators_) {
        TORCH_WARN("available key ", kv.first);
      }
| return FAIL; |
| } |
| const ValidateFunc& validator = it->second.second; |
| if (validator(value) != OK) { |
| TORCH_WARN("Failed validator: ", key); |
| return FAIL; |
| } |
| } |
| |
| return OK; |
| } |
| |
| void TuningResultsValidator::RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf) { |
| if (validators_.find(key) != validators_.end()) { |
| TORCH_WARN("Attempting to re-register validator with key ", key); |
| } |
| else { |
| validators_[key] = std::make_pair(gf, vf); |
| } |
| } |
| |
| std::string TuningResultsValidator::GetPyTorchVersion() const { |
| return TORCH_VERSION; |
| } |
| |
| TuningStatus TuningResultsValidator::ValidatePyTorchVersion(const std::string& value) const { |
| TUNABLE_LOG1("PT_VERSION validation: expect ", value, " to match ", GetPyTorchVersion()); |
| if (value == GetPyTorchVersion()) { |
| return OK; |
| } |
| return FAIL; |
| } |
| |
| // TuningContext |
| |
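// Defaults: TunableOp disabled, tuning enabled (effective only while TunableOp
// is enabled), a 30 ms / 100 iteration tuning budget per op, no warmup, icache
// flushing enabled, and a rotating buffer sized from the device L2 cache.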
| TuningContext::TuningContext() : |
| enable_{false}, |
| tuning_enable_{true}, |
| manager_initialized_{false}, |
| write_file_on_exit_{true}, |
| numerics_check_enable_{false}, |
| max_tuning_duration_ms_{30}, |
| max_tuning_iterations_{100}, |
| max_warmup_duration_ms_{0}, |
| max_warmup_iterations_{0}, |
| icache_flush_{true}, |
| rotating_buffer_size_{-1}, |
| filename_{}, |
| results_count_from_input_file_{0} |
| { |
| } |
| |
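// On destruction, rewrites the results file if tuning was enabled, a filename
// is set, and tuning produced more results than were loaded at startup.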
| TuningContext::~TuningContext() { |
| if (!manager_initialized_) { |
| // TuningResultsManager was never initialized, no tuning requested or performed. |
| // This can happen in a DDP job where a python process spawns other workers |
| // but doesn't do any computation itself. |
| return; |
| } |
| auto filename = GetFilename(); |
| if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) { |
| if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) { |
| if (results_count_from_input_file_ > 0) { |
| TUNABLE_LOG1("additional tuning results available, rewriting file ", filename); |
| } |
| else { |
| TUNABLE_LOG1("writing file ", filename); |
| } |
| if (!WriteFile(filename)) { |
| TUNABLE_LOG1("failed to write file ", filename); |
| } |
| } |
| } |
| } |
| |
| void TuningContext::EnableTunableOp(bool value) { |
| enable_ = value; |
| if (value) { |
| TUNABLE_LOG1("Enable TunableOp"); |
| } |
| else { |
| TUNABLE_LOG1("Disable TunableOp"); |
| } |
| } |
| |
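// PYTORCH_TUNABLEOP_ENABLED=1 forces TunableOp on regardless of the
// programmatic setting; the environment variable is read once and cached.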
| bool TuningContext::IsTunableOpEnabled() const { |
| static const char *env = std::getenv("PYTORCH_TUNABLEOP_ENABLED"); |
| if (env != nullptr && strcmp(env, "1") == 0) { |
| return true; |
| } |
| return enable_; |
| } |
| |
| void TuningContext::EnableTuning(bool value) { |
| tuning_enable_ = value; |
| if (value) { |
| TUNABLE_LOG1("Enable Tuning for TunableOp"); |
| } |
| else { |
| TUNABLE_LOG1("Disable Tuning for TunableOp"); |
| } |
| } |
| |
| bool TuningContext::IsTuningEnabled() const { |
| static const char *env = std::getenv("PYTORCH_TUNABLEOP_TUNING"); |
| if (env != nullptr && strcmp(env, "0") == 0) { |
| return false; |
| } |
| return tuning_enable_; |
| } |
| |
| void TuningContext::WriteFileOnExit(bool value) { |
| write_file_on_exit_ = value; |
| } |
| |
| void TuningContext::EnableNumericsCheck(bool value) { |
| numerics_check_enable_ = value; |
| } |
| |
| bool TuningContext::IsNumericsCheckEnabled() const { |
  const char *env = std::getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
| if (env != nullptr && strcmp(env, "1") == 0) { |
| return true; |
| } |
| return numerics_check_enable_; |
| } |
| |
| void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) { |
| max_tuning_duration_ms_ = max_duration_ms < 0 ? 0 : max_duration_ms; |
| } |
| |
| int TuningContext::GetMaxTuningDurationMs() const { |
| static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"); |
| if (env != nullptr) { |
| int val = atoi(env); |
| return val < 0 ? 0 : val; |
| } |
| return max_tuning_duration_ms_; |
| } |
| |
| void TuningContext::SetMaxTuningIterations(int max_iter) { |
| max_tuning_iterations_ = max_iter < 0 ? 0 : max_iter; |
| } |
| |
| int TuningContext::GetMaxTuningIterations() const { |
| static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"); |
| if (env != nullptr) { |
| int val = atoi(env); |
| return val < 0 ? 0 : val; |
| } |
| return max_tuning_iterations_; |
| } |
| |
| void TuningContext::SetMaxWarmupDurationMs(int max_duration_ms) { |
| max_warmup_duration_ms_ = max_duration_ms < 0 ? 0 : max_duration_ms; |
| } |
| |
| int TuningContext::GetMaxWarmupDurationMs() const { |
| static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"); |
| if (env != nullptr) { |
| int val = atoi(env); |
| return val < 0 ? 0 : val; |
| } |
| return max_warmup_duration_ms_; |
| } |
| |
| void TuningContext::SetMaxWarmupIterations(int max_iter) { |
| max_warmup_iterations_ = max_iter < 0 ? 0 : max_iter; |
| } |
| |
| int TuningContext::GetMaxWarmupIterations() const { |
| static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS"); |
| if (env != nullptr) { |
| int val = atoi(env); |
| return val < 0 ? 0 : val; |
| } |
| return max_warmup_iterations_; |
| } |
| |
| void TuningContext::EnableICacheFlush(bool value) { |
| icache_flush_ = value; |
| } |
| |
| bool TuningContext::IsICacheFlushEnabled() const { |
| static const char *env = std::getenv("PYTORCH_TUNABLEOP_ICACHE_FLUSH_ENABLED"); |
| if (env != nullptr && strcmp(env, "0") == 0) { |
| return false; |
| } |
| return icache_flush_; |
| } |
| |
| void TuningContext::SetRotatingBufferSize(int size) { |
| rotating_buffer_size_ = size < 0 ? 0 : size; |
| } |
| |
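// PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE is specified in MB, but the buffer
// size is returned in bytes. Without the env var, a negative stored size (the
// default) means "use the device's L2 cache size".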
| int TuningContext::GetRotatingBufferSize() const { |
| static const char *env = std::getenv("PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"); |
| if (env != nullptr) { |
| constexpr int MB = 1024 * 1024; |
| int val = atoi(env); |
| return val < 0 ? 0 : val * MB; // env var is specified as MB, returned as bytes |
| } |
| else { |
| if (rotating_buffer_size_ < 0) { |
| // negative buffer size (default) means query for L2 cache size |
| int l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize; |
| return l2_cache_size; |
| } |
| else { |
| return rotating_buffer_size_; |
| } |
| } |
| } |
| |
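// Lazily initializes the manager exactly once: resolves the results filename
// (PYTORCH_TUNABLEOP_FILENAME, else "tunableop_results.csv"), loads any prior
// results from it, and verifies the file is writable so failures surface early.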
| TuningResultsManager& TuningContext::GetTuningResultsManager() { |
| c10::call_once(manager_init_once_, [this]() { |
| manager_initialized_ = true; |
| if (GetFilename().empty()) { |
| // if SetFilename() was not already called, call it now with the default or env var |
| const char *env = std::getenv("PYTORCH_TUNABLEOP_FILENAME"); |
| std::string filename = (env == nullptr) ? "tunableop_results.csv" : env; |
| SetFilename(filename, true); |
| } |
| auto filename = GetFilename(); |
| if (!filename.empty()) { |
| ReadFile(filename); |
| // attempt immediately to open file for writing to catch errors early |
| std::ofstream file(filename, std::ios::out | std::ios::app); |
| if (!file.good()) { |
| TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved"); |
| } |
| } |
| }); |
| return manager_; |
| } |
| |
| TuningResultsValidator& TuningContext::GetTuningResultsValidator() { |
| return validator_; |
| } |
| |
| TuningResults TuningContext::GetTuningResults() { |
| TuningResults tr; |
| tr.validators = GetTuningResultsValidator().GetAllValidators(); |
| tr.results = GetTuningResultsManager().Dump(); |
| return tr; |
| } |
| |
| TuningStatus TuningContext::LoadTuningResults(const TuningResults& tr) { |
  TORCH_CHECK(GetTuningResultsValidator().ValidateAll(tr.validators) == OK);
| GetTuningResultsManager().Load(tr.results); |
| return OK; |
| } |
| |
| void TuningContext::SetFilename(const std::string& filename, bool insert_device_ordinal) { |
| filename_ = filename; |
| |
| if (filename_.empty()) { |
| return; |
| } |
| |
| if (insert_device_ordinal) { |
| // differentiate filename based on device ordinal to avoid |
| // use case of one process per device writing to same file |
| std::string device = c10::str(int(c10::cuda::current_device())); |
| |
| // does filename contain %d to insert device ordinal in specific location? |
| const std::string TOKEN("%d"); |
| std::size_t found = filename_.find(TOKEN); |
| if (found != std::string::npos) { |
| filename_.replace(found, TOKEN.length(), device); |
| } |
| else { |
| // no %d present, so append device ordinal before final '.' |
| found = filename_.rfind('.'); |
| if (found != std::string::npos) { |
| filename_.insert(found, device); |
| } |
| else { |
| // all else fails, just append |
| filename_.append(device); |
| } |
| } |
| } |
| } |
| |
| std::string TuningContext::GetFilename() const { |
| return filename_; |
| } |
| |
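// Parses a results file previously produced by WriteFile(). Two line formats
// are accepted (values below are illustrative):
//   Validator,PT_VERSION,2.3.0
//   <op_signature>,<params_signature>,<kernel_id>[,<time>]
// Results are loaded only if every validator in the file passes.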
bool TuningContext::ReadFile(const std::string& filename_arg) {
  std::string filename = filename_arg.empty() ? GetFilename() : filename_arg;
| TUNABLE_LOG1("reading tuning results from ", filename); |
| ResultsMap results; |
| std::unordered_map<std::string, std::string> validators; |
| std::string line; |
| std::ifstream file(filename); |
| if (!file) { |
| TUNABLE_LOG1("could not open ", filename, " for reading tuning results"); |
| return false; |
| } |
| while (std::getline(file, line)) { |
| if (line.empty()) { |
| continue; |
| } |
| std::string part; |
| std::vector<std::string> parts; |
| std::stringstream line_as_stream(line); |
| while (std::getline(line_as_stream, part, ',')) { |
| parts.push_back(part); |
| } |
| if (parts[0] == "Validator" && parts.size() >= 3) { |
| validators[parts[1]] = parts[2]; |
| TUNABLE_LOG1("Validator ", parts[1], "=", parts[2]); |
| } |
| else if (parts.size() >= 4) { |
| results[parts[0]].emplace(parts[1], ResultEntry(parts[2], atof(parts[3].c_str()))); |
| } |
| else if (parts.size() >= 3) { |
| // the timestamp from the file is optional |
| results[parts[0]].emplace(parts[1], ResultEntry(parts[2], 0)); |
| } |
| else { |
| TUNABLE_LOG1("could not parse line: ", line); |
| } |
| } |
| if (GetTuningResultsValidator().ValidateAll(validators) != FAIL) { |
| manager_.Load(results); |
| results_count_from_input_file_ = manager_.GetSize(); |
| } |
| else { |
| TUNABLE_LOG1("results validator check failed"); |
| return false; |
| } |
| return true; |
| } |
| |
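// Writes all validators followed by all tuning results, truncating any
// existing file of the same name.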
bool TuningContext::WriteFile(const std::string& filename_arg) {
  std::string filename = filename_arg.empty() ? GetFilename() : filename_arg;
| std::ofstream file(filename, std::ios::out | std::ios::trunc); |
| if (!file.good()) { |
| TUNABLE_LOG1("error opening tuning results file for writing ", filename); |
| return false; |
| } |
| auto validators = GetTuningResultsValidator().GetAllValidators(); |
| for (const auto& [key, val] : validators) { |
| file << "Validator," << key << "," << val << std::endl; |
| } |
| auto results = GetTuningResultsManager().Dump(); |
| for (const auto& [op_sig, kernelmap] : results) { |
| for (const auto& [param_sig, result] : kernelmap) { |
| file << op_sig << "," << param_sig << "," << result << std::endl; |
| } |
| } |
| file.close(); |
| return true; |
| } |
| |
| } // namespace at::cuda::tunable |