# blob: 4cf74f40c342a6d243bec11af429ad5bce9d926c [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Plugins for Google Cloud Platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import six
from google.api_core import retry
from google.cloud import tasks_v2
from google3.google.protobuf import timestamp_pb2
from tradefed_cluster.plugins import base
# Retry policy handed to the Cloud Tasks create_task/delete_task calls below.
# deadline=60 bounds the total retry time (seconds, per google.api_core docs).
DEFAULT_RETRY_OPTION = retry.Retry(deadline=60)
class TaskScheduler(base.TaskScheduler):
  """A task scheduler for Cloud Tasks.

  Tasks are created as App Engine HTTP POST requests whose relative URI is
  `/_ah/queue/<queue_name>`.
  """

  def __init__(self, project, location):
    """Constructor.

    Args:
      project: a project name.
      location: a project location.
    """
    self._project = project
    self._location = location
    # Holds one Cloud Tasks client per thread; populated by _GetClient.
    self._thread_storage = threading.local()

  def _GetClient(self):
    """Returns the calling thread's Cloud Tasks client, creating it lazily.

    The Cloud Tasks client is not threadsafe, so it is stored in thread-local
    storage and each thread gets its own instance.
    """
    client = getattr(self._thread_storage, 'tasks_client', None)
    if client is None:
      client = tasks_v2.CloudTasksClient()
      self._thread_storage.tasks_client = client
    return client

  def AddTask(self, queue_name, payload, target, task_name, eta):
    """Creates a task in the given queue.

    Args:
      queue_name: name of the queue to add the task to.
      payload: the request body; converted to bytes with six.ensure_binary.
      target: App Engine service to route the task to; skipped when falsy.
      task_name: name for the new task; when falsy no name is sent.
      eta: a datetime to schedule the task at; skipped when falsy.

    Returns:
      A base.Task holding the last path segment of the created task's name,
      the binary payload and the given eta.
    """
    client = self._GetClient()
    parent = client.queue_path(self._project, self._location, queue_name)
    # ensure_binary returns bytes unchanged, so no isinstance guard is needed.
    payload = six.ensure_binary(payload)
    routing = {}
    if target:
      routing['service'] = target
    task = {
        'app_engine_http_request': {
            'http_method': 'POST',
            'app_engine_routing': routing,
            'relative_uri': '/_ah/queue/%s' % queue_name,
            'body': payload
        }
    }
    if task_name:
      task['name'] = client.task_path(
          self._project, self._location, queue_name, task_name)
    if eta:
      timestamp = timestamp_pb2.Timestamp()
      timestamp.FromDatetime(eta)
      task['schedule_time'] = timestamp
    created = client.create_task(
        parent=parent, task=task, retry=DEFAULT_RETRY_OPTION)
    # The API returns a full resource path; callers only see the short name.
    return base.Task(
        name=created.name.split('/')[-1], payload=payload, eta=eta)

  def DeleteTask(self, queue_name, task_name):
    """Deletes a task from the given queue.

    Args:
      queue_name: name of the queue that holds the task.
      task_name: name of the task to delete.
    """
    client = self._GetClient()
    task_path = client.task_path(
        self._project, self._location, queue_name, task_name)
    # Pass the path by keyword: in google-cloud-tasks 2.x the first positional
    # parameter is `request`, while `name=` is accepted by both 1.x and 2.x.
    client.delete_task(name=task_path, retry=DEFAULT_RETRY_OPTION)