[ROCm] TunableOp logging improvements (#132173)
Summary:
TunableOp logging improvements:
1. PYTORCH_TUNABLEOP_VERBOSE=1: print the expected vs. actual value for each TunableOp validator, so that when validation fails we know exactly how to fix it
2. PYTORCH_TUNABLEOP_VERBOSE=3: print the exact kernel signature for both successful and failed kernel lookups (a usage sketch follows this list)
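As a usage sketch (not part of this PR), enabling TunableOp and the new verbose logging from Python might look like the following; the environment variables are the documented TunableOp knobs, and the shapes are illustrative:
```
import os

# TunableOp reads these environment variables; set them before the first GEMM runs.
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"   # turn TunableOp on
os.environ["PYTORCH_TUNABLEOP_TUNING"] = "1"    # allow online tuning of missing entries
os.environ["PYTORCH_TUNABLEOP_VERBOSE"] = "3"   # 1: validator logs, 3: per-lookup logs

import torch

a = torch.randn(2, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 8192, device="cuda", dtype=torch.bfloat16)
c = a @ b  # GEMM dispatches through GemmTunableOp; validator and lookup logs print as shown below
```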
Test Plan:
> PYTORCH_TUNABLEOP_VERBOSE=3 buck2 run mode/{opt,amd-gpu} -c fbcode.enable_gpu_sections=true //scripts/xdwang/example:fc_llama -- --enable-tuning
```
reading tuning results from hipblas_tuning_pt_llama0.csv
Validator PT_VERSION=2.5.0
Validator ROCBLAS_VERSION=4.0.0-72e57364-dirty
Validator HIPBLASLT_VERSION=800-a15e4178
Validator ROCM_VERSION=6.0.0.0-12969-1544e39
Validator GCN_ARCH_NAME=gfx942:sramecc+:xnack-
GCN_ARCH_NAME validation: expect gfx942:sramecc+:xnack- to match gfx942:sramecc+:xnack-
ROCM_VERSION validation: expect 6.0.0.0-12969-1544e39 to match 6.0.0.0-12969-1544e39
HIPBLASLT_VERSION validation: expect 800-a15e4178 to match 800-a15e4178
ROCBLAS_VERSION validation: expect 4.0.0-72e57364-dirty to match 4.0.0-72e57364-dirty
PT_VERSION validation: expect 2.5.0 to match 2.5.0
Loading results
GemmTunableOp_BFloat16_TN(tn_8192_2_1024) -> Gemm_Hipblaslt_TN_61169,0.0171694
GemmTunableOp_BFloat16_TN(tn_7168_2_8192) -> Gemm_Hipblaslt_TN_61089,0.036138
GemmTunableOp_BFloat16_TN(tn_8192_2_3584) -> Gemm_Hipblaslt_TN_61169,0.0240673
missing params_signature, returning null ResultEntry for GemmTunableOp_BFloat16_TN,tn_1280_2_8192
finding fastest for GemmTunableOp_BFloat16_TN(tn_1280_2_8192) out of 2818 candidates
Rotating buffer 4 MiB. Needed Size: 20 MiB. Needed number of param copies: 1
├──tuning using warmup iters 0 [0 ms] and tuning iters 1 [0.208254 ms] instance id=0, GemmTunableOp_BFloat16_TN(tn_1280_2_8192) Default
├──offset at 3
......
ResultEntry found for GemmTunableOp_BFloat16_TN,tn_8192_2_3584
ResultEntry found for GemmTunableOp_BFloat16_TN,tn_8192_2_3584
ResultEntry found for GemmTunableOp_BFloat16_TN,tn_8192_2_3584
Avg time: 16.42832040786743 us, Achieved 7.15 TFLOPS, 3578.07 GB/s
2x1280x8192-torch.bfloat16,16.260499954223633,2.5794434438103107,1294.0669757533708
2x8192x1024-torch.bfloat16,16.15394949913025,2.0771658350056508,1041.11852032876
2x7168x8192-torch.bfloat16,25.691540241241455,9.14234887416194,4574.841325057144
2x8192x3584-torch.bfloat16,16.42832040786743,7.1486621324818085,3578.0709494714856
```
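For reference, the validators in the log above compare against the header of the tuning-results file named in the first log line. A hypothetical excerpt, following the CSV layout described in the torch.cuda.tunable docs, with values copied from the log for illustration:
```
Validator,PT_VERSION,2.5.0
Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
Validator,HIPBLASLT_VERSION,800-a15e4178
Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
GemmTunableOp_BFloat16_TN,tn_8192_2_1024,Gemm_Hipblaslt_TN_61169,0.0171694
GemmTunableOp_BFloat16_TN,tn_7168_2_8192,Gemm_Hipblaslt_TN_61089,0.036138
```
With VERBOSE=1, a mismatch in any Validator row is now reported with both values; with VERBOSE=3, each lookup into the per-op kernel map logs its op and params signature whether it hits or misses.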
Differential Revision: D60468273
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132173
Approved by: https://github.com/mxz297, https://github.com/jeffdaily, https://github.com/eqy
diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp
index 49b90e9..52ec306 100644
--- a/aten/src/ATen/cuda/tunable/Tunable.cpp
+++ b/aten/src/ATen/cuda/tunable/Tunable.cpp
@@ -65,16 +65,17 @@
std::scoped_lock l{lock_};
auto kernel_map_it = results_.find(op_signature);
if (kernel_map_it == results_.cend()) {
- TUNABLE_LOG3("missing op_signature, returning null ResultEntry");
+ TUNABLE_LOG3("missing op_signature, returning null ResultEntry for ", op_signature, ",", params_signature);
return ResultEntry::Null();
}
const auto& km = kernel_map_it->second;
auto it = km.find(params_signature);
if (it == km.cend()) {
- TUNABLE_LOG3("missing params_signature, returning null ResultEntry");
+ TUNABLE_LOG3("missing params_signature, returning null ResultEntry for ", op_signature, ",", params_signature);
return ResultEntry::Null();
}
+ TUNABLE_LOG3("ResultEntry found for ", op_signature, ",", params_signature);
return it->second;
}
@@ -282,6 +283,7 @@
}
TuningStatus TuningResultsValidator::ValidatePyTorchVersion(const std::string& value) const {
+ TUNABLE_LOG1("PT_VERSION validation: expect ", value, " to match ", GetPyTorchVersion());
if (value == GetPyTorchVersion()) {
return OK;
}
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
index 2efac15..779ada5 100644
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -187,7 +187,10 @@
getTuningContext()->GetTuningResultsValidator().RegisterValidator(
"ROCBLAS_VERSION",
[rocblas_version]() { return rocblas_version; },
- [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; });
+ [rocblas_version](auto&& k) {
+ TUNABLE_LOG1("ROCBLAS_VERSION validation: expect ", k, " to match ", rocblas_version);
+ return rocblas_version == k ? OK : FAIL;
+ });
}
}
@@ -205,6 +208,7 @@
"HIPBLASLT_VERSION",
[hipblaslt_version]() { return hipblaslt_version; },
[hipblaslt_version](auto&& k) {
+ TUNABLE_LOG1("HIPBLASLT_VERSION validation: expect ", k, " to match ", hipblaslt_version);
return hipblaslt_version == k ? OK : FAIL;
});
}
@@ -217,7 +221,10 @@
getTuningContext()->GetTuningResultsValidator().RegisterValidator(
"ROCM_VERSION",
[rocm_version]() { return rocm_version; },
- [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; });
+ [rocm_version](auto&& k) {
+ TUNABLE_LOG1("ROCM_VERSION validation: expect ", k, " to match ", rocm_version);
+ return rocm_version == k ? OK : FAIL;
+ });
}
if (validators.find("GCN_ARCH_NAME") == validators.end()) {
@@ -225,7 +232,10 @@
getTuningContext()->GetTuningResultsValidator().RegisterValidator(
"GCN_ARCH_NAME",
[gcn_arch_name]() { return gcn_arch_name; },
- [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; });
+ [gcn_arch_name](auto&& k) {
+ TUNABLE_LOG1("GCN_ARCH_NAME validation: expect ", k, " to match ", gcn_arch_name);
+ return gcn_arch_name == k ? OK : FAIL;
+ });
}
}
#endif