rpc_flight_recorder: add shard updating option
BUG=chromium:715386
TEST=run locally
Change-Id: I45a2f065d00672a375e008067a608953835ce490
Reviewed-on: https://chromium-review.googlesource.com/533638
Commit-Ready: Chris Ching <chingcodes@chromium.org>
Tested-by: Chris Ching <chingcodes@chromium.org>
Reviewed-by: Chris Ching <chingcodes@chromium.org>
diff --git a/site_utils/rpc_flight_recorder.py b/site_utils/rpc_flight_recorder.py
index bda878d..bd9e88c 100755
--- a/site_utils/rpc_flight_recorder.py
+++ b/site_utils/rpc_flight_recorder.py
@@ -13,6 +13,10 @@
from autotest_lib.client.common_lib import global_config
from autotest_lib.frontend.afe.json_rpc import proxy
from autotest_lib.server import frontend
+# import needed to setup host_attributes
+# pylint: disable=unused-import
+from autotest_lib.server import site_host_attributes
+from autotest_lib.site_utils import server_manager_utils
from chromite.lib import commandline
from chromite.lib import cros_logging as logging
from chromite.lib import metrics
@@ -27,7 +31,6 @@
proxy.JSONRPCException: 'JSONRPCException',
}
-
def afe_rpc_call(hostname):
"""Perform one rpc call set on server
@@ -42,34 +45,147 @@
logging.exception(e)
+def update_shards(shards, shards_lock, period=600, stop_event=None):
+ """Updates dict of shards
+
+ @param shards: list of shards to be updated
+ @param shards_lock: shared lock for accessing shards
+ @param period: time between polls
+ @param stop_event: Event that can be set to stop polling
+ """
+ while(not stop_event or not stop_event.is_set()):
+ start_time = time.time()
+
+ logging.debug('Updating Shards')
+ new_shards = set(server_manager_utils.get_shards())
+
+ with shards_lock:
+ current_shards = set(shards)
+ rm_shards = current_shards - new_shards
+ add_shards = new_shards - current_shards
+
+ if rm_shards:
+ for s in rm_shards:
+ shards.remove(s)
+
+ if add_shards:
+ shards.extend(add_shards)
+
+ if rm_shards:
+ logging.info('Servers left production: %s', str(rm_shards))
+
+ if add_shards:
+ logging.info('Servers entered production: %s',
+ str(add_shards))
+
+ wait_time = (start_time + period) - time.time()
+ if wait_time > 0:
+ time.sleep(wait_time)
+
+
+def poll_rpc_servers(servers, servers_lock, shards=None, period=60,
+ stop_event=None):
+ """Blocking function that polls all servers and shards
+
+ @param servers: list of servers to poll
+ @param servers_lock: lock to be used when accessing servers or shards
+ @param shards: list of shards to poll
+ @param period: time between polls
+ @param stop_event: Event that can be set to stop polling
+ """
+ pool = multiprocessing.Pool(processes=multiprocessing.cpu_count() * 4)
+
+ while(not stop_event or not stop_event.is_set()):
+ start_time = time.time()
+ with servers_lock:
+ all_servers = set(servers).union(shards)
+
+ logging.debug('Starting Server Polling: %s', ', '.join(all_servers))
+ pool.map(afe_rpc_call, all_servers)
+
+ logging.debug('Finished Server Polling')
+
+ metrics.Counter(METRIC_TICK).increment()
+
+ wait_time = (start_time + period) - time.time()
+ if wait_time > 0:
+ time.sleep(wait_time)
+
+
class RpcFlightRecorder(object):
"""Monitors a list of AFE"""
- def __init__(self, servers, poll_period=60):
+ def __init__(self, servers, with_shards=True, poll_period=60):
"""
- @pram servers: list of afe services to monitor
- @pram poll_period: frequency to poll all services, in seconds
+ @param servers: list of afe services to monitor
+ @param with_shards: also record status on shards
+ @param poll_period: frequency to poll all services, in seconds
"""
- self._servers = set(servers)
+ self._manager = multiprocessing.Manager()
+
self._poll_period = poll_period
- self._pool = multiprocessing.Pool(processes=20)
+
+ self._servers = self._manager.list(servers)
+ self._servers_lock = self._manager.RLock()
+
+ self._with_shards = with_shards
+ self._shards = self._manager.list()
+ self._update_shards_ps = None
+ self._poll_rpc_server_ps = None
+
+ self._stop_event = multiprocessing.Event()
+
+ def start(self):
+ """Call to start recorder"""
+ if(self._with_shards):
+ shard_args = [self._shards, self._servers_lock]
+ shard_kwargs = {'stop_event': self._stop_event}
+ self._update_shards_ps = multiprocessing.Process(
+ name='update_shards',
+ target=update_shards,
+ args=shard_args,
+ kwargs=shard_kwargs)
+
+ self._update_shards_ps.start()
+
+ poll_args = [self._servers, self._servers_lock]
+ poll_kwargs= {'shards':self._shards,
+ 'period':self._poll_period,
+ 'stop_event':self._stop_event}
+ self._poll_rpc_server_ps = multiprocessing.Process(
+ name='poll_rpc_servers',
+ target=poll_rpc_servers,
+ args=poll_args,
+ kwargs=poll_kwargs)
+
+ self._poll_rpc_server_ps.start()
+
+ def close(self):
+ """Send close event to all sub processes"""
+ self._stop_event.set()
- def poll_servers(self):
- """Blocking function that polls all servers and shards"""
- while(True):
- start_time = time.time()
- logging.debug('Starting Server Polling: %s' %
- ', '.join(self._servers))
+ def termitate(self):
+ """Terminate processes"""
+ self.close()
+ if self._poll_rpc_server_ps:
+ self._poll_rpc_server_ps.terminate()
- self._pool.map(afe_rpc_call, self._servers)
+ if self._update_shards_ps:
+ self._update_shards_ps.terminate()
- logging.debug('Finished Server Polling')
+ if self._manager:
+ self._manager.shutdown()
- metrics.Counter(METRIC_TICK).increment()
- wait_time = (start_time + self._poll_period) - time.time()
- if wait_time > 0:
- time.sleep(wait_time)
+ def join(self, timeout=None):
+ """Blocking call until closed and processes complete
+
+ @param timeout: passed to each process, so could be >timeout"""
+ if self._poll_rpc_server_ps:
+ self._poll_rpc_server_ps.join(timeout)
+
+ if self._update_shards_ps:
+ self._update_shards_ps.join(timeout)
def _failed(fields, msg_str, reason, err=None):
"""Mark current run failed
@@ -149,6 +265,10 @@
parser.add_argument('-p', '--poll-period', type=int, default=60,
help='Frequency to poll AFE servers')
+
+ parser.add_argument('--no-shards', action='store_false', dest='with_shards',
+ help='Disable shard updating')
+
return parser
@@ -167,9 +287,12 @@
with ts_mon_config.SetupTsMonGlobalState('rpc_flight_recorder',
indirect=True):
- afe_monitor = RpcFlightRecorder(options.afe,
- poll_period=options.poll_period)
- afe_monitor.poll_servers()
+ flight_recorder = RpcFlightRecorder(options.afe,
+ with_shards=options.with_shards,
+ poll_period=options.poll_period)
+
+ flight_recorder.start()
+ flight_recorder.join()
if __name__ == '__main__':