blob: 8403dd6f21aa85cd75c6cbeae7e6f72ab21c07f1 [file] [log] [blame]
#!/usr/bin/env python3
#
# Copyright 2020 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from mobly import signals
from threading import Condition
from cert.event_stream import static_remaining_time_delta
from cert.truth import assertThat
class IHasBehaviors(ABC):
    """Interface for objects that expose a behaviors container.

    Consumed by ``when`` and ``wait_until`` in this module, which both call
    ``get_behaviors()`` on the object they are given.
    """

    @abstractmethod
    def get_behaviors(self):
        """Return the object that holds this instance's behaviors."""
def anything():
    """Return a matcher that accepts every object it is given."""

    def _match_all(obj):
        return True

    return _match_all
def when(has_behaviors):
    """Return the behaviors container of ``has_behaviors``.

    Asserts (via ``assertThat``) that the argument is an ``IHasBehaviors``
    before calling its ``get_behaviors()``.
    """
    implements_interface = isinstance(has_behaviors, IHasBehaviors)
    assertThat(implements_interface).isTrue()
    behaviors = has_behaviors.get_behaviors()
    return behaviors
def IGNORE_UNHANDLED(obj):
    """Default handler that accepts ``obj`` and does nothing with it."""
class SingleArgumentBehavior(object):
    """Dispatches single-argument calls to registered behavior instances.

    ``run(obj)`` offers ``obj`` to the registered instances in registration
    order; the first whose ``try_run`` returns True handles it. If none does,
    the default function is used, or the test fails when no default is set.
    Every handled object is recorded so ``wait_until_invoked`` can wait for it.
    """

    def __init__(self, reply_stage_factory):
        # Called to create reply stages for begin(); its class name is also
        # used in the "went unhandled" failure message.
        self._reply_stage_factory = reply_stage_factory
        self._instances = []
        # Objects handled by run(); guarded by _invoked_condition.
        self._invoked_obj = []
        self._invoked_condition = Condition()
        self.set_default_to_crash()

    def begin(self, matcher):
        """Start defining a behavior for objects accepted by ``matcher``."""
        return PersistenceStage(self, matcher, self._reply_stage_factory)

    def append(self, behavior_instance):
        """Register an instance; instances appended earlier are tried first."""
        self._instances.append(behavior_instance)

    def set_default(self, fn):
        """Use ``fn`` for objects that no registered instance handles."""
        assertThat(fn).isNotNone()
        self._default_fn = fn

    def set_default_to_crash(self):
        """Make unhandled objects fail the test (the initial setting)."""
        self._default_fn = None

    def set_default_to_ignore(self):
        """Silently accept objects that no registered instance handles."""
        self._default_fn = IGNORE_UNHANDLED

    def run(self, obj):
        """Dispatch ``obj`` to the first registered instance that accepts it.

        Falls back to the default function when no instance handles ``obj``.
        Objects handled either way are recorded for ``wait_until_invoked``.

        Raises:
            signals.TestFailure: no instance handled ``obj`` and no default
                function is set.
        """
        for instance in self._instances:
            if instance.try_run(obj):
                self.__obj_invoked(obj)
                return
        if self._default_fn is not None:
            # IGNORE_UNHANDLED is also a default fn
            self._default_fn(obj)
            self.__obj_invoked(obj)
        else:
            raise signals.TestFailure(
                "%s: behavior for %s went unhandled" % (self._reply_stage_factory().__class__.__name__, obj),
                extras=None)

    def __obj_invoked(self, obj):
        """Record a handled object and wake every waiter."""
        with self._invoked_condition:
            self._invoked_obj.append(obj)
            # Several threads may wait concurrently with different matchers;
            # wake them all so each can re-check its own predicate.
            self._invoked_condition.notify_all()

    def wait_until_invoked(self, matcher, times, timeout):
        """Wait until ``times`` recorded objects satisfy ``matcher``.

        Returns as soon as at least ``times`` matching objects have been
        recorded (immediately if that is already the case), or when
        ``timeout`` (a ``timedelta``) elapses.

        Returns:
            True if the number of matching recorded objects equals ``times``;
            False if fewer matched before the timeout or more than ``times``
            have matched.
        """

        def count_matching():
            return sum(1 for invoked in self._invoked_obj if matcher(invoked))

        with self._invoked_condition:
            # wait_for evaluates the predicate before sleeping and again under
            # the lock after every notify, so an invocation recorded between
            # two checks cannot be missed, and the final count is never stale.
            self._invoked_condition.wait_for(
                lambda: count_matching() >= times, timeout.total_seconds())
            return count_matching() == times
class PersistenceStage(object):
    """Fluent stage that decides how many times a behavior may fire.

    Produces a reply stage, via ``reply_stage_factory``, that has been
    initialized with the behavior, the matcher and the chosen count.
    """

    def __init__(self, behavior, matcher, reply_stage_factory):
        self._behavior = behavior
        self._matcher = matcher
        self._reply_stage_factory = reply_stage_factory

    def then(self, times=1):
        """Return a reply stage that lets the behavior fire ``times`` times."""
        stage = self._reply_stage_factory()
        stage.init(self._behavior, self._matcher, times)
        return stage

    def always(self):
        """Return a reply stage with no limit on how often the behavior fires.

        The count is passed as -1; BehaviorInstance only enforces a limit
        when the count is non-negative.
        """
        unlimited = -1
        return self.then(times=unlimited)
class ReplyStage(object):
    """Base class for the final stage of a behavior definition.

    ``PersistenceStage.then`` constructs an instance and then calls ``init``.
    Subclasses call ``_commit`` with the handler function to register it.
    """

    def init(self, behavior, matcher, persistence):
        """Store the target behavior, the matcher and the allowed call count."""
        self._behavior, self._matcher, self._persistence = behavior, matcher, persistence

    def _commit(self, fn):
        """Append a BehaviorInstance that runs ``fn`` to the target behavior."""
        instance = BehaviorInstance(self._matcher, self._persistence, fn)
        self._behavior.append(instance)
class BehaviorInstance(object):
    """One registered rule: a matcher, a call budget and a handler function."""

    def __init__(self, matcher, persistence, fn):
        self._matcher = matcher
        # Maximum number of runs; a negative value means unlimited.
        self._persistence = persistence
        self._fn = fn
        self._called_count = 0

    def try_run(self, obj):
        """Run the handler on ``obj`` if it matches and the budget allows.

        The call count is incremented before the handler runs, so a handler
        that raises still consumes one use.

        Returns:
            True if the handler was invoked, False otherwise.
        """
        budget_exhausted = 0 <= self._persistence <= self._called_count
        if self._matcher(obj) and not budget_exhausted:
            self._called_count += 1
            self._fn(obj)
            return True
        return False
class BoundVerificationStage(object):
    """A behavior and matcher bound to a timeout, awaiting a call count."""

    def __init__(self, behavior, matcher, timeout):
        self._behavior = behavior
        self._matcher = matcher
        self._timeout = timeout

    def times(self, times=1):
        """Delegate to the behavior's ``wait_until_invoked`` and return its result."""
        behavior, matcher, timeout = self._behavior, self._matcher, self._timeout
        return behavior.wait_until_invoked(matcher, times, timeout)
class WaitForBehaviorSubject(object):
    """Fluent helper that turns ``subject.foo(matcher)`` into a verification stage.

    Accessing attribute ``foo`` looks up ``foo_behavior`` on the wrapped
    behaviors object. It returns a callable that takes a matcher and builds a
    ``BoundVerificationStage`` using this subject's timeout.
    """

    def __init__(self, behaviors, timeout):
        self._behaviors = behaviors
        self._timeout = timeout

    def __getattr__(self, item):
        # __getattr__ runs only for names not found normally. Refuse private
        # and dunder names: on an instance whose __init__ has not run (e.g.
        # copy/pickle probing __setstate__ or __deepcopy__), reading
        # self._behaviors below would re-enter __getattr__ until RecursionError.
        if item.startswith("_"):
            raise AttributeError("%r object has no attribute %r" % (type(self).__name__, item))
        behavior = getattr(self._behaviors, item + "_behavior")
        timeout = self._timeout
        return lambda matcher: BoundVerificationStage(behavior, matcher, timeout)
def wait_until(i_has_behaviors, timeout=timedelta(seconds=3)):
    """Wrap the behaviors of ``i_has_behaviors`` for timed verification.

    Returns a ``WaitForBehaviorSubject`` built from
    ``i_has_behaviors.get_behaviors()`` and ``timeout`` (default 3 seconds).
    """
    behaviors = i_has_behaviors.get_behaviors()
    subject = WaitForBehaviorSubject(behaviors, timeout)
    return subject