Integrate cuBLASLt autotuning: replace the stream_executor AutoTuneBatchMatmul singleton with a BlasPlansAutotuneCache shared by the GEMM algorithm picker and the GEMM thunk.
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
index 6b40900..e812bf5 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
@@ -128,8 +128,7 @@
// to the index within the algorithms vector, not the algorithm
// itself.
se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm);
- if (!se::AutoTuneBatchMatmul::GetInstance()->Find(matmul_parameters,
- &algorithm_config)) {
+ if (!blas_plans_autotune_cache.Find(matmul_parameters, &algorithm_config)) {
VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size()
<< " algorithms.";
se::blas::ProfileResult best_result;
@@ -192,8 +191,7 @@
<< " for " << trans_x << " " << trans_y << " " << m << " " << n
<< " " << k << " " << batch_size << " " << broadcast << " "
<< broadcast << " " << dtype << " " << device_id;
- se::AutoTuneBatchMatmul::GetInstance()->Insert(matmul_parameters,
- algorithm_config);
+ blas_plans_autotune_cache.Insert(matmul_parameters, algorithm_config);
}
return Status::OK();
}
@@ -383,14 +381,14 @@
DoBlasPlansAutotune(stream, instr, allocator, input_output_allocator,
gemm_config, element_type, cublas_autotune_level,
lhs_buffer, rhs_buffer, output_buffer));
+ return {se::blas::kNoAlgorithm};
};
- return {se::blas::kNoAlgorithm};
} else {
GemmCacheKey key =
std::make_tuple(stream->parent(), lhs->shape(), rhs->shape(),
instr->shape(), gemm_config.SerializeAsString());
- tensorflow::mutex_lock cache_lock(autotune_cache_mu);
+ absl::MutexLock cache_lock(&autotune_cache_mu);
auto it = autotune_cache.find(key);
int64_t autotuning_requests = cache_hits + cache_misses;
if (autotuning_requests && autotuning_requests % 10 == 0) {
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index 599695f..3da3313 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -94,6 +94,25 @@
&scratch_allocator, nullptr);
}
+// Looks up a previously autotuned cuBLASLt algorithm for `params`.
+// Returns true and fills `*config` on a cache hit; returns false and
+// leaves `*config` untouched on a miss. Thread-safe.
+bool BlasPlansAutotuneCache::Find(const se::BatchMatmulParameters &params,
+                                  se::blas::AlgorithmConfig *config) {
+  absl::MutexLock lock(&mu_);
+  auto iter = blas_plans_algorithms_map_.find(params);
+  if (iter == blas_plans_algorithms_map_.end()) {
+    return false;
+  }
+  *config = iter->second;
+  return true;
+}
+
+// Records the autotuned algorithm for `params`. First insertion wins:
+// later calls with the same key are no-ops. Thread-safe.
+void BlasPlansAutotuneCache::Insert(const se::BatchMatmulParameters &params,
+                                    const se::blas::AlgorithmConfig &config) {
+  absl::MutexLock lock(&mu_);
+  // flat_hash_map::insert is already a no-op when the key exists, so a
+  // separate contains() check would just be a second hash lookup.
+  blas_plans_algorithms_map_.insert({params, config});
+}
+
// Converts from an XLA PrimitiveType to a blas::ComputationType, which is
// used to specify the precision with which matmul computations should be
// performed, separately from the precision of the inputs and result.
@@ -245,16 +264,14 @@
se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm);
-  // When autotuner is disabled, AutoTuneBatchMatmul singleton is empty.
+  // When autotuning is disabled, blas_plans_autotune_cache is empty.
- if (!se::AutoTuneBatchMatmul::GetInstance()->Find(matmul_parameters,
- &algorithm_config)) {
+ if (!blas_plans_autotune_cache.Find(matmul_parameters, &algorithm_config)) {
algorithm_config.set_algorithm(0);
VLOG(4) << "Autotuner disabled: Inserting algorithm id "
<< algorithm_config.algorithm() << " for " << trans_x << " "
<< trans_y << " " << m << " " << n << " " << k << " "
<< batch_size << " " << broadcast << " " << broadcast << " "
<< dtype << " " << device_id;
- se::AutoTuneBatchMatmul::GetInstance()->Insert(matmul_parameters,
- algorithm_config);
+ blas_plans_autotune_cache.Insert(matmul_parameters, algorithm_config);
}
se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm();
algorithm_ptr = algorithms[algorithm_idx].get();
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
index 1fee7e4..a5730f3 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
@@ -120,6 +120,31 @@
return compatible_types.contains(type);
}
+// Cache mapping batched-matmul parameters to the cuBLASLt algorithm chosen
+// by autotuning. All operations are thread-safe.
+class BlasPlansAutotuneCache {
+ public:
+  BlasPlansAutotuneCache() = default;
+
+  // Returns true and sets `*config` if an algorithm for `params` was
+  // previously inserted; returns false on a cache miss.
+  bool Find(const se::BatchMatmulParameters& params,
+            se::blas::AlgorithmConfig* config);
+
+  // Records `config` for `params`; a no-op if an entry already exists.
+  void Insert(const se::BatchMatmulParameters& params,
+              const se::blas::AlgorithmConfig& config);
+
+ private:
+  // Adapts BatchMatmulParameters::hash() to the functor shape
+  // absl::flat_hash_map expects.
+  struct Hasher {
+    std::size_t operator()(const se::BatchMatmulParameters& parameter) const {
+      return parameter.hash();
+    }
+  };
+
+  mutable absl::Mutex mu_;
+  absl::flat_hash_map<se::BatchMatmulParameters, se::blas::AlgorithmConfig,
+                      Hasher>
+      blas_plans_algorithms_map_ ABSL_GUARDED_BY(mu_);
+  TF_DISALLOW_COPY_AND_ASSIGN(BlasPlansAutotuneCache);
+};
+
+// `inline` (C++17) so that every translation unit including this header
+// shares ONE cache instance. A `static` here would give each TU
+// (gemm_thunk.cc and gemm_algorithm_picker.cc) its own private copy, so
+// algorithms inserted by the autotuner would never be found by the thunk.
+inline BlasPlansAutotuneCache blas_plans_autotune_cache;
+
} // namespace gpu
} // namespace xla