rpc_flight_recorder: add shard updating option

Poll shard AFE servers in addition to the configured main servers.
The shard list is refreshed periodically via server_manager_utils,
and polling now runs in dedicated subprocesses started and joined by
the recorder. A new --no-shards flag disables shard polling.

BUG=chromium:715386
TEST=run locally

Change-Id: I45a2f065d00672a375e008067a608953835ce490
Reviewed-on: https://chromium-review.googlesource.com/533638
Commit-Ready: Chris Ching <chingcodes@chromium.org>
Tested-by: Chris Ching <chingcodes@chromium.org>
Reviewed-by: Chris Ching <chingcodes@chromium.org>
diff --git a/site_utils/rpc_flight_recorder.py b/site_utils/rpc_flight_recorder.py
index bda878d..bd9e88c 100755
--- a/site_utils/rpc_flight_recorder.py
+++ b/site_utils/rpc_flight_recorder.py
@@ -13,6 +13,10 @@
 from autotest_lib.client.common_lib import global_config
 from autotest_lib.frontend.afe.json_rpc import proxy
 from autotest_lib.server import frontend
+# import needed to setup host_attributes
+# pylint: disable=unused-import
+from autotest_lib.server import site_host_attributes
+from autotest_lib.site_utils import server_manager_utils
 from chromite.lib import commandline
 from chromite.lib import cros_logging as logging
 from chromite.lib import metrics
@@ -27,7 +31,6 @@
         proxy.JSONRPCException: 'JSONRPCException',
         }
 
-
 def afe_rpc_call(hostname):
     """Perform one rpc call set on server
 
@@ -42,34 +45,160 @@
         logging.exception(e)
 
 
+def update_shards(shards, shards_lock, period=600, stop_event=None):
+    """Updates dict of shards
+
+    @param shards: list of shards to be updated
+    @param shards_lock: shared lock for accessing shards
+    @param period: seconds between polls
+    @param stop_event: Event that can be set to stop polling
+    """
+    while not stop_event or not stop_event.is_set():
+        start_time = time.time()
+
+        logging.debug('Updating Shards')
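+        # server_manager_utils.get_shards() returns the hostnames of
+        # the shard servers currently in production.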
+        new_shards = set(server_manager_utils.get_shards())
+
+        with shards_lock:
+            current_shards = set(shards)
+            rm_shards = current_shards - new_shards
+            add_shards = new_shards - current_shards
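+            # Mutate the managed list in place so the polling process
+            # holding the same proxy sees the update.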
+
+            for s in rm_shards:
+                shards.remove(s)
+            shards.extend(add_shards)
+
+        if rm_shards:
+            logging.info('Servers left production: %s', rm_shards)
+
+        if add_shards:
+            logging.info('Servers entered production: %s', add_shards)
+
+        wait_time = (start_time + period) - time.time()
+        if wait_time > 0:
+            time.sleep(wait_time)
+
+
+def poll_rpc_servers(servers, servers_lock, shards=None, period=60,
+                     stop_event=None):
+    """Blocking function that polls all servers and shards
+
+    @param servers: list of servers to poll
+    @param servers_lock: lock to be used when accessing servers or shards
+    @param shards: list of shards to poll
+    @param period: seconds between polls
+    @param stop_event: Event that can be set to stop polling
+    """
+    pool = multiprocessing.Pool(processes=multiprocessing.cpu_count() * 4)
+
+    while not stop_event or not stop_event.is_set():
+        start_time = time.time()
+        with servers_lock:
+            all_servers = set(servers).union(shards or [])
+
+        logging.debug('Starting Server Polling: %s', ', '.join(all_servers))
+        pool.map(afe_rpc_call, all_servers)
+
+        logging.debug('Finished Server Polling')
+
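+        # Emit one monitoring tick per completed polling pass.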
+        metrics.Counter(METRIC_TICK).increment()
+
+        wait_time = (start_time + period) - time.time()
+        if wait_time > 0:
+            time.sleep(wait_time)
+
+
 class RpcFlightRecorder(object):
     """Monitors a list of AFE"""
-    def __init__(self, servers, poll_period=60):
+    def __init__(self, servers, with_shards=True, poll_period=60):
         """
-        @pram servers: list of afe services to monitor
-        @pram poll_period: frequency to poll all services, in seconds
+        @param servers: list of afe services to monitor
+        @param with_shards: also record status on shards
+        @param poll_period: frequency to poll all services, in seconds
         """
-        self._servers = set(servers)
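+        # A Manager supplies list and lock proxies that can be shared
+        # safely with the child processes.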
+        self._manager = multiprocessing.Manager()
+
         self._poll_period = poll_period
-        self._pool = multiprocessing.Pool(processes=20)
+
+        self._servers = self._manager.list(servers)
+        self._servers_lock = self._manager.RLock()
+
+        self._with_shards = with_shards
+        self._shards = self._manager.list()
+        self._update_shards_ps = None
+        self._poll_rpc_server_ps = None
+
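+        # Shared flag checked by both loops; close() sets it to request
+        # a graceful shutdown after the current poll cycle.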
+        self._stop_event = multiprocessing.Event()
+
+    def start(self):
+        """Call to start recorder"""
+        if self._with_shards:
+            shard_args = [self._shards, self._servers_lock]
+            shard_kwargs = {'stop_event': self._stop_event}
+            self._update_shards_ps = multiprocessing.Process(
+                    name='update_shards',
+                    target=update_shards,
+                    args=shard_args,
+                    kwargs=shard_kwargs)
+
+            self._update_shards_ps.start()
+
+        poll_args = [self._servers, self._servers_lock]
+        poll_kwargs = {'shards': self._shards,
+                       'period': self._poll_period,
+                       'stop_event': self._stop_event}
+        self._poll_rpc_server_ps = multiprocessing.Process(
+                name='poll_rpc_servers',
+                target=poll_rpc_servers,
+                args=poll_args,
+                kwargs=poll_kwargs)
+
+        self._poll_rpc_server_ps.start()
+
+    def close(self):
+        """Send close event to all sub processes"""
+        self._stop_event.set()
 
 
-    def poll_servers(self):
-        """Blocking function that polls all servers and shards"""
-        while(True):
-            start_time = time.time()
-            logging.debug('Starting Server Polling: %s' %
-                          ', '.join(self._servers))
+    def terminate(self):
+        """Terminate processes"""
+        self.close()
+        if self._poll_rpc_server_ps:
+            self._poll_rpc_server_ps.terminate()
 
-            self._pool.map(afe_rpc_call, self._servers)
+        if self._update_shards_ps:
+            self._update_shards_ps.terminate()
 
-            logging.debug('Finished Server Polling')
+        if self._manager:
+            self._manager.shutdown()
 
-            metrics.Counter(METRIC_TICK).increment()
 
-            wait_time = (start_time + self._poll_period) - time.time()
-            if wait_time > 0:
-                time.sleep(wait_time)
+    def join(self, timeout=None):
+        """Blocking call until closed and processes complete
+
+        @param timeout: passed to each process, so could be >timeout"""
+        if self._poll_rpc_server_ps:
+            self._poll_rpc_server_ps.join(timeout)
+
+        if self._update_shards_ps:
+            self._update_shards_ps.join(timeout)
 
 def _failed(fields, msg_str, reason, err=None):
     """Mark current run failed
@@ -149,6 +278,12 @@
 
     parser.add_argument('-p', '--poll-period', type=int, default=60,
                         help='Frequency to poll AFE servers')
+
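+    # store_false means with_shards defaults to True and is cleared
+    # only when --no-shards is passed.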
+    parser.add_argument('--no-shards', action='store_false', dest='with_shards',
+                        help='Disable polling of shard servers')
+
     return parser
 
 
@@ -167,9 +302,12 @@
 
     with ts_mon_config.SetupTsMonGlobalState('rpc_flight_recorder',
                                              indirect=True):
-        afe_monitor = RpcFlightRecorder(options.afe,
-                                        poll_period=options.poll_period)
-        afe_monitor.poll_servers()
+        flight_recorder = RpcFlightRecorder(options.afe,
+                                            with_shards=options.with_shards,
+                                            poll_period=options.poll_period)
+
+        flight_recorder.start()
+        flight_recorder.join()
 
 
 if __name__ == '__main__':