[profiler] Add an option to initialize the kineto profiler on start-up (#87226) (#88020)

Summary:
# Initialize Kineto Profiler for on-demand profiling

## TLDR
Overall, this patch enables initializing the kineto profiling library on start-up, guarded by an environment variable described in more detail below. Otherwise, the kineto profiler continues to be initialized lazily when the PyTorch profiler is invoked.

## Background
We are enabling an on-demand profiling capability for PyTorch. As users run large distributed training flows, this will let them capture a PyTorch profiler/GPU trace remotely, from outside the process. The kineto library and a monitoring daemon, dynolog, interact to achieve this.

Dynolog will be open sourced by the end of October and has been dogfooded on the Meta AI Research cluster.
https://github.com/facebookincubator/dynolog

### How it works
The kineto library registers itself with the dynolog daemon running on the host over inter-process communication:
```
  | kineto  |   --> (ipcfabric)  --> | dynolog |
   * register()
   * poll for on-demand tracing configs()
```
This feature is currently enabled by setting the env variable `KINETO_USE_DAEMON`. However, it only works if kineto is initialized; otherwise the thread that talks to dynolog is never spun up.
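
For illustration, besides exporting the variable in the shell (as in the test plan below), a script can opt in from Python as long as the variable is set before `torch` is imported, since the guarded init runs from a static initializer when the torch libraries load. A minimal sketch, assuming that ordering:
```python
import os

# Must happen before `import torch`: the env check and libkineto_init()
# run from a static initializer when the torch libraries are loaded.
os.environ["KINETO_USE_DAEMON"] = "1"

import torch  # kineto initializes and registers with dynolog here

# ... the rest of the training script runs as usual; a GPU trace can
# later be requested from outside the process via `dyno gputrace`.
```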

Related PRs in kineto include
https://github.com/pytorch/kineto/pull/637
https://github.com/pytorch/kineto/pull/653

## Test Plan
Build PyTorch from source (need to set `USE_LITE_INTERPRETER_PROFILER=OFF`).

Run a simple linear model [example](https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html).
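
The exact script is not part of this patch; a minimal stand-in consistent with the output below (the file name `linear_model.py`, the layer sizes, and the learning rate are assumptions) could look like:
```python
import torch

# Use the GPU if available; the runs below print "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

# Fit a random linear mapping with a single linear layer.
x = torch.randn(64, 1000, device=device)
y = torch.randn(64, 10, device=device)

model = torch.nn.Linear(1000, 10).to(device)
loss_fn = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

for step in range(20000):
    loss = loss_fn(model(x), y)
    if step % 100 == 99:
        print(step, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```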

### First run with the env variable set
```
export KINETO_CONFIG=/private/home/bcoutinho//libkineto.conf
export KINETO_USE_DAEMON=1
python3 /private/home/bcoutinho/linear_model.py
```
Output
```
INFO:2022-10-18 09:01:12 4169946:4169946 init.cpp:98] Registering daemon config loader
cuda:0
```
We can trigger a trace using the dynolog client tool
```
#> dyno gputrace --log-file /tmp/gpu_trace_test.json
response length = 147
response = {"activityProfilersBusy":0,"activityProfilersTriggered":[4116844],"eventProfilersBusy":0,"eventProfilersTriggered":[],"processesMatched":[4116844]}
Matched 1 processes
Trace output files will be written to:
    /tmp/gpu_trace_test_4116844.json
```

### Run without the env variable
```
 python3 ../../linear_model.py
cuda:0
99 1425.056884765625
10099 8.817168235778809
```

## Side effects of initialization

Currently the environment variable should guard users from picking this change up unless they intend to. `libkineto_init` does set up the CUPTI APIs and spins up a thread to read on-demand configurations. This should not be problematic; we can provide a more granular init in the future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87226

Reviewed By: chaekit

Differential Revision: D40558184

Pulled By: briancoutinho

fbshipit-source-id: afea7502b1d72201c00994c87fde63a35783f4d5

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88020
Approved by: https://github.com/chaekit
diff --git a/torch/csrc/profiler/kineto_client_interface.cpp b/torch/csrc/profiler/kineto_client_interface.cpp
index 1eea700..c9e07ca 100644
--- a/torch/csrc/profiler/kineto_client_interface.cpp
+++ b/torch/csrc/profiler/kineto_client_interface.cpp
@@ -1,6 +1,7 @@
 #ifdef USE_KINETO
 #include <libkineto.h>
 #include <torch/csrc/autograd/profiler_kineto.h>
+#include <cstdlib>
 
 // Ondemand tracing is not supported on Apple or edge platform
 #if defined(__APPLE__) || defined(EDGE_PROFILER_USE_KINETO)
@@ -61,12 +62,22 @@
 } // namespace profiler
 
 #if ENABLE_GLOBAL_OBSERVER
+namespace {
+
 struct RegisterLibKinetoClient {
   RegisterLibKinetoClient() {
     static profiler::impl::LibKinetoClient client;
+
+    if (std::getenv("KINETO_USE_DAEMON") != nullptr) {
+      libkineto_init(/*cpuOnly=*/false, /*logOnError=*/true);
+      libkineto::api().suppressLogMessages();
+    }
+
     libkineto::api().registerClient(&client);
   }
 } register_libkineto_client;
+
+} // namespace
 #endif
 
 } // namespace torch