Address review comments about XlaConfigRegistry.
1. Fix an oversight where __COUNTER__ in the registration macro was not expanded.
2. Guard the Register() method with a mutex.
3. Polish comments.
diff --git a/tensorflow/core/util/xla_config_registry.h b/tensorflow/core/util/xla_config_registry.h
index 99bcff1..7d1cb50 100644
--- a/tensorflow/core/util/xla_config_registry.h
+++ b/tensorflow/core/util/xla_config_registry.h
@@ -18,6 +18,7 @@
#include <functional>
#include "tensorflow/core/platform/default/logging.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
@@ -26,18 +27,22 @@
// its status.
class XlaConfigRegistry {
public:
+ // XlaGlobalJitLevel is used by XLA to expose its JIT levels for processing
+ // single-GPU graphs and general (multi-GPU) graphs.
struct XlaGlobalJitLevel {
OptimizerOptions::GlobalJitLevel single_gpu;
OptimizerOptions::GlobalJitLevel general;
};
- // Input is jit_level in session config, and return is the config from
- // XLA, reflecting the effect of the environment variable flags.
+ // Getter type: takes the jit_level from the session config and returns the
+ // jit_level from XLA, reflecting the effect of the environment variable flags.
typedef std::function<XlaGlobalJitLevel(
const OptimizerOptions::GlobalJitLevel&)>
GlobalJitLevelGetterTy;
static void Register(XlaConfigRegistry::GlobalJitLevelGetterTy getter) {
+ static mutex mu(LINKER_INITIALIZED);  // One mutex shared by every Register() call.
+ mutex_lock l(mu);  // Serializes the check-and-set below across callers.
CHECK(!global_jit_level_getter_);
global_jit_level_getter_ = std::move(getter);
}
@@ -55,7 +60,13 @@
};
#define REGISTER_XLA_CONFIG_GETTER(getter) \
- static bool registered_##__COUNTER__ = \
+ REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(__COUNTER__, getter)
+// Forwarding through this helper expands __COUNTER__ before `##` pastes it.
+#define REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(ctr, getter) \
+ REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter)
+// Defines a uniquely named static bool whose initializer calls Register().
+#define REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter) \
+ static bool xla_config_registry_registration_##ctr = \
(::tensorflow::XlaConfigRegistry::Register(getter), true)
} // namespace tensorflow