| # Copyright 2019 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """A commander module to process requests and schedule commands.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import json |
| import logging |
| import zlib |
| |
| import flask |
| |
| |
| from tradefed_cluster import command_manager |
| from tradefed_cluster import command_monitor |
| from tradefed_cluster import command_task_store |
| from tradefed_cluster import common |
| from tradefed_cluster import datastore_entities |
| from tradefed_cluster import env_config |
| from tradefed_cluster import metric |
| from tradefed_cluster import request_manager |
| from tradefed_cluster.util import command_util |
| |
# Task-queue push handler path; App Engine delivers REQUEST_QUEUE messages here.
REQUEST_HANDLER_PATH = "/_ah/queue/%s" % request_manager.REQUEST_QUEUE

# The max number of command shards per request.
# LINT.IfChange(max_shard_count)
DEFAULT_MAX_SHARDS = 20
# Per-run-target overrides for the shard cap; any run target not listed here
# falls back to DEFAULT_MAX_SHARDS (see _CreateCommands).
RUN_TARGET_TO_MAX_SHARDS_MAP = {
    # Allow more shards for virtual device run targets.
    "RemoteAvdIDevice": 100,
    "TcpDevice": 100
}


# WSGI application serving the request queue handler below.
APP = flask.Flask(__name__)
| |
| |
@common.RetryNdbContentionErrors
def _ProcessRequest(request_id):
  """Process a request and schedule corresponding commands.

  Looks up the request, creates its commands if they do not exist yet, and
  schedules any commands still in an UNKNOWN state. Invalid requests
  (ValueError) are cancelled with an INVALID_REQUEST reason.

  Args:
    request_id: request id, str
  """
  request_id = str(request_id)
  request = request_manager.GetRequest(request_id)
  if not request:
    # request_id is a string here; the original "%d" placeholder raised a
    # formatting error inside logging and the message was dropped.
    logging.error("Request %s doesn't exist in ds.", request_id)
    return
  logging.debug("Processing request %s: %s",
                request_id, request)

  # It is important to make the following block to be idempotent since it can be
  # retried.
  try:
    # We don't have to worry about a case where only some commands are created
    # because all commands are created in a transaction
    # (see command_manager.CreateCommands).
    commands = command_manager.GetCommands(request_id)
    if not commands:
      commands = _CreateCommands(request)
    if request.max_concurrent_tasks:
      # Only schedule (request.max_concurrent_tasks) commands.
      commands = commands[:request.max_concurrent_tasks]

    # If the command is in an UNKNOWN state, it means that not all tasks may
    # have been scheduled for the command. Since we may retry processing a
    # request for various reasons, we should only try to schedule commands that
    # are in an UNKNOWN state.
    pending_commands = [
        c for c in commands if c.state == common.CommandState.UNKNOWN
    ]
    if pending_commands:
      command_manager.ScheduleTasks(
          pending_commands, update_request_state=False)
    command_monitor.Monitor(commands)

    # Update the request state to start request event processing.
    request_manager.EvaluateState(request_id)
  except (AssertionError, ValueError) as e:
    logging.exception("Failed to process request %s", request_id)
    # Only a ValueError (bad request data) gets an explicit cancel reason;
    # AssertionError cancels with reason None.
    cancel_reason = None
    if isinstance(e, ValueError):
      cancel_reason = common.CancelReason.INVALID_REQUEST
    request_manager.CancelRequest(request_id, cancel_reason)
    command_manager.CancelCommands(request_id, cancel_reason)
| |
| |
def _CreateCommands(request):
  """Create a list of commands for a request.

  Each command_info in the request is validated and then expanded into
  shard_count commands. For unmanaged requests (falsy request.type), sharding
  parameters are injected into the command line unless the original command
  line already used local sharding (--shard-count without --shard-index).

  Args:
    request: a request entity with command_infos, type, priority, etc.

  Returns:
    a list of command entities created by command_manager.CreateCommands.

  Raises:
    ValueError: if a command_info is missing a cluster or run target, has a
      run count below 1, or has a shard count outside [1, max_shards].
  """
  expanded_command_infos = []
  # Parallel list: the shard index for each entry of expanded_command_infos.
  shard_indexes = []
  for command_info in request.command_infos:
    if command_info.cluster is None:
      raise ValueError("cluster is not specified.")
    if not command_info.run_target:
      raise ValueError("run target is not defined.")
    # TODO: Check in db to see that it is a valid run target.
    if command_info.run_count < 1:
      raise ValueError("run count must be equal or greater than 1.")

    # Virtual-device run targets get a higher shard cap (see
    # RUN_TARGET_TO_MAX_SHARDS_MAP); everything else uses the default.
    max_shards = RUN_TARGET_TO_MAX_SHARDS_MAP.get(
        command_info.run_target, DEFAULT_MAX_SHARDS)
    if not 0 < command_info.shard_count <= max_shards:
      raise ValueError("shard count %d is outside of range [1, %d]" %
                       (command_info.shard_count, max_shards))
    # TODO: Move validity check to request_manager.

    command_line = command_util.CommandLine(command_info.command_line)
    command_line.RemoveOptions([
        # TFC-specific options
        "--cluster",
        "--run-target",
        "--run-count",

        # TF conflicting options
        "--loop",  # causes TF to loop test runs continuously
        "--product-type",  # causes TF to fail device allocations
        "--test-iterations",  # specifies the number of iterations to run
    ])
    # Schedule commands and tag them with a run_target.
    # TF implicitly knows how to map a device to a run_target string. When
    # fetching commands, TF looks for only commands tagged with run_targets
    # which are available on itself.
    for shard_index in range(command_info.shard_count):
      # If the request is unmanaged, use command line to inject shard
      # parameters.
      if not request.type:
        # If local sharding was defined keep the original shard setup
        local_sharding = False
        if command_line.GetOption(
            "--shard-count") is not None and command_line.GetOption(
                "--shard-index") is None:
          local_sharding = True

        if not local_sharding:
          # NOTE: command_line is shared across shard iterations; each pass
          # strips the previous shard options before adding this shard's.
          command_line.RemoveOptions(["--shard-count", "--shard-index"])
          if command_info.shard_count > 1:
            command_line.AddOption(
                "--shard-count", str(command_info.shard_count))
            command_line.AddOption("--shard-index", str(shard_index))
      expanded_command_infos.append(
          datastore_entities.CommandInfo(
              name=command_info.name,
              command_line=command_line.ToTFString(),
              cluster=command_info.cluster,
              run_target=command_info.run_target,
              run_count=command_info.run_count,
              shard_count=command_info.shard_count,
              allow_partial_device_match=(
                  command_info.allow_partial_device_match),
              test_bench=command_info.test_bench
          ))
      shard_indexes.append(shard_index)

  commands = command_manager.CreateCommands(
      request_id=request.key.id(),
      request_plugin_data=request.plugin_data,
      command_infos=expanded_command_infos,
      shard_indexes=shard_indexes,
      priority=request.priority,
      queue_timeout_seconds=request.queue_timeout_seconds,
      request_type=request.type,
      affinity_tag=request.affinity_tag)
  # Propagate the previous test context (e.g. for retries) to every command.
  if request.prev_test_context:
    for command in commands:
      command_manager.UpdateTestContext(
          request_id=request.key.id(),
          command_id=command.key.id(),
          test_context=request.prev_test_context)
  return commands
| |
| |
| @APP.route(REQUEST_HANDLER_PATH, methods=["POST"]) |
| def HandleRequest(): |
| """Process a request message.""" |
| body = flask.request.get_data() |
| try: |
| body = zlib.decompress(body) |
| except zlib.error: |
| logging.warning( |
| "payload may not be compressed: %s", body, exc_info=True) |
| payload = json.loads(body) |
| request_id = payload["id"] |
| _ProcessRequest(request_id) |
| return common.HTTP_OK |
| |
| |
def ProcessCommandEvent(event):
  """Updates state of a command and coordinate command tasks.

  Flow: update the command attempt, record metrics for final attempt states,
  update the command's state (unless the event belongs to a stale attempt),
  delete tasks for finalized commands, notify the configured plugin, update
  the request state, and finally schedule any pending commands that fit
  within the request's concurrency budget.

  Args:
    event: a CommandEvent
  """
  command = command_manager.GetCommand(event.request_id, event.command_id)
  if not command:
    logging.warning(
        "unknown command %s %s; ignored", event.request_id, event.command_id)
    return

  is_updated = command_manager.UpdateCommandAttempt(event)

  # No need to coordinate if the event is old but continue if it is final.
  # We continue if it is final as datastore update to the command and the
  # attempt aren't done in the same transaction, so a failed command update
  # should be retried in that event.
  if not is_updated and not common.IsFinalCommandState(event.attempt_state):
    logging.debug("Command attempt is not updated.")
    return

  if common.IsFinalCommandState(event.attempt_state):
    metric.RecordCommandAttemptMetric(
        cluster_id=command.cluster,
        run_target=command.run_target,
        state=event.attempt_state.name)

  # Guard against stale events: only update command state if the event's
  # attempt is the one currently associated with the task (or the task has
  # no attempt / no longer exists).
  task = command_task_store.GetTask(event.task_id)
  if task and task.attempt_id and event.attempt_id != task.attempt_id:
    logging.debug(
        "Skipping command update. Event attempt_id %s does not "
        "match the task attempt_id %s", event.attempt_id, task.attempt_id)
  else:
    # Update command.
    command = command_manager.UpdateState(
        event.request_id,
        event.command_id,
        attempt_state=event.attempt_state,
        task_id=event.task_id)

    if common.IsFinalCommandState(command.state):
      # Deschedule command since the state indicates that it is not supposed
      # to run anymore.
      logging.debug("Command %r is finalized, delete all its tasks.",
                    command.key)
      command_manager.DeleteTasks(command)

  # Update AnTS.
  # Re-fetch command and attempt so the plugin sees the post-update entities.
  env_config.CONFIG.plugin.OnProcessCommandEvent(
      command_manager.GetCommand(event.request_id, event.command_id),
      command_manager.GetCommandAttempt(
          event.request_id, event.command_id, event.attempt_id),
      event_data=event.data)

  # Update request.
  request, request_summary = request_manager.EvaluateState(event.request_id)

  # If the request throttles concurrency, possibly schedule more commands.
  _CheckPendingCommands(request, request_summary)
| |
| |
def _CheckPendingCommands(request, request_summary):
  """Check pending commands and schedule if necessary.

  Used for requests throttled via max_concurrent_tasks: when a command
  finishes, schedule as many pending (UNKNOWN state) commands as fit within
  the remaining concurrency budget.

  Args:
    request: a request entity; may have max_concurrent_tasks set.
    request_summary: a summary with pending/queued/running command counts,
      or None.
  """
  if common.IsFinalRequestState(request.state):
    # Nothing left to schedule once the request reached a final state.
    return
  if not request.max_concurrent_tasks:
    # No throttling; all commands were scheduled up front.
    return
  if not request_summary or request_summary.pending_count <= 0:
    return
  # Fixed an unbalanced ")" that used to trail the %d placeholder.
  logging.info(
      "Checking pending commands for request %s: max_concurrent_tasks=%d",
      request.key.id(), request.max_concurrent_tasks)
  active_command_count = (
      request_summary.queued_count + request_summary.running_count)
  logging.info(
      "active_command_count = %d, pending_command_count = %d",
      active_command_count, request_summary.pending_count)
  # Schedule only enough commands to fill the concurrency budget, capped by
  # how many are actually still pending.
  next_command_count = min(
      request.max_concurrent_tasks - active_command_count,
      request_summary.pending_count)
  logging.info("next_command_count = %d", next_command_count)
  if next_command_count > 0:
    logging.info("Scheduling %d next commands", next_command_count)
    next_commands = command_manager.GetCommands(
        request.key.id(), common.CommandState.UNKNOWN)[:next_command_count]
    command_manager.ScheduleTasks(next_commands)
    command_monitor.Monitor(next_commands)