blob: ef4c8b82efa69c7b63c03f350c5271b0215fdc2f [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#include <mutex> // NOLINT(build/c++11)
// We need a pair of compile time and runtime flags to disable compilation of
// custom contraction kernels for unsupported architectures (e.g. Android,
// iOS, ARM and PPC CPUs, etc...), and to be able to fallback on default Eigen
// matrix multiplication at runtime.
//
// It's not allowed to use absl flags library in Tensorflow, so we have to pass
// the configuration through the environment variable.
//
// Example:
// bazel test --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false //test
#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
namespace Eigen {
namespace internal {
// TODO(ezhulenev): This is a temporary workaround for disabling custom kernels
// at runtime in tests. We should always rely on compile time flags for that.
// Example: ... --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false //test
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE bool UseCustomContractionKernels() {
static bool use_custom_contraction_kernel = true;
// This subroutine should not be used in GPU. In case it is, a custom kernel
// should always be used
#if !defined __NVCC__ && !defined __HIP_DEVICE_COMPILE__
static std::once_flag initialized;
std::call_once(initialized, [&] {
char* flag = std::getenv("TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL");
if (flag && (strcmp(flag, "false") == 0 || strcmp(flag, "0") == 0)) {
use_custom_contraction_kernel = false;
}
});
#endif
return use_custom_contraction_kernel;
}
} // namespace internal
} // namespace Eigen
#endif