## This file is part of Scapy
## See http://www.secdev.org/projects/scapy for more information
## Copyright (C) Philippe Biondi <phil@secdev.org>
## Copyright (C) Gabriel Potter <gabriel@potter.fr>
## This program is published under a GPLv2 license

"""
Automata with states, transitions and actions.
"""

from __future__ import absolute_import
import itertools
import os
import socket
import sys
import threading
import time
import traceback
import types
from collections import deque
from select import select

from scapy.config import conf
from scapy.utils import do_graph
from scapy.error import log_interactive, warning
from scapy.plist import PacketList
from scapy.data import MTU
from scapy.supersocket import SuperSocket
from scapy.consts import WINDOWS
from scapy.compat import *
import scapy.modules.six as six

# Exception raised when releasing a threading.Lock that is not held.
try:
    import thread
    # Python 2: the low-level ``thread`` module exposes it as thread.error
    THREAD_EXCEPTION = thread.error
except ImportError:
    # Python 3 has no ``thread`` module; the same misuse raises RuntimeError
    THREAD_EXCEPTION = RuntimeError

# Exception type(s) from a socket recv() that callers in this module catch
# explicitly. Off Windows it is an empty tuple, so ``except recv_error``
# matches nothing.
if not WINDOWS:
    recv_error = ()
else:
    from scapy.error import Scapy_Exception
    recv_error = Scapy_Exception

""" In Windows, select.select is not available for custom objects. Here's the implementation of scapy to re-create this functionnality
# Passive way: using no-ressources locks
               +---------+             +---------------+      +-------------------------+
               |  Start  +------------->Select_objects +----->+Linux: call select.select|
               +---------+             |(select.select)|      +-------------------------+
                                       +-------+-------+
                                               |
                                          +----v----+               +--------+
                                          | Windows |               |Time Out+----------------------------------+
                                          +----+----+               +----+---+                                  |
                                               |                         ^                                      |
      Event                                    |                         |                                      |
        +                                      |                         |                                      |
        |                              +-------v-------+                 |                                      |
        |                       +------+Selectable Sel.+-----+-----------------+-----------+                    |
        |                       |      +-------+-------+     |           |     |           v              +-----v-----+
+-------v----------+            |              |             |           |     |        Passive lock<-----+release_all<------+
|Data added to list|       +----v-----+  +-----v-----+  +----v-----+     v     v            +             +-----------+      |
+--------+---------+       |Selectable|  |Selectable |  |Selectable|   ............         |                                |
         |                 +----+-----+  +-----------+  +----------+                        |                                |
         |                      v                                                           |                                |
         v                 +----+------+   +------------------+               +-------------v-------------------+            |
   +-----+------+          |wait_return+-->+  check_recv:     |               |                                 |            |
   |call_release|          +----+------+   |If data is in list|               |  END state: selectable returned |        +---+--------+
   +-----+--------              v          +-------+----------+               |                                 |        | exit door  |
         |                    else                 |                          +---------------------------------+        +---+--------+
         |                      +                  |                                                                         |
         |                 +----v-------+          |                                                                         |
         +--------->free -->Passive lock|          |                                                                         |
                           +----+-------+          |                                                                         |
                                |                  |                                                                         |
                                |                  v                                                                         |
                                +------------------Selectable-Selector-is-advertised-that-the-selectable-is-readable---------+
"""

class SelectableObject:
    """DEV: to implement one of those, you need to add 2 things to your object:
    - add "check_recv" function
    - call "self.call_release" once you are ready to be read

    You can set the __selectable_force_select__ to True in the class, if you want to
    force the handler to use fileno(). This may only be useable on sockets created using
    the builtin socket API."""
    __selectable_force_select__ = False

    def check_recv(self):
        """DEV: will be called only once (at beginning) to check if the object is ready."""
        raise OSError("This method must be overwriten.")

    def _wait_non_ressources(self, callback):
        """Waiter thread body started by wait_return().

        Blocks on self.trigger, which wait_return() already holds, until
        call_release() frees it. It then advertises the object by calling
        *callback(self)*, unless the release was an abort
        (call_release(arborted=True)).
        """
        self.trigger.acquire()
        if not self.was_ended:
            callback(self)

    def wait_return(self, callback):
        """Entry point of SelectableObject: register the callback.

        If check_recv() reports the object as ready, *callback(self)* is
        called synchronously and its result returned. Otherwise a daemon
        thread is started, and it calls *callback(self)* once call_release()
        is invoked without the abort flag.
        """
        if self.check_recv():
            return callback(self)
        # Create and hold the trigger *before* the waiter thread exists. A
        # call_release() racing with thread start-up then always finds this
        # lock instead of a missing or stale one, so the wakeup is not lost.
        self.was_ended = False
        self.trigger = threading.Lock()
        self.trigger.acquire()
        # Data may have arrived between the first check and the creation of
        # the trigger; its call_release() would have targeted the old lock.
        if self.check_recv():
            return callback(self)
        _t = threading.Thread(target=self._wait_non_ressources, args=(callback,))
        _t.daemon = True  # Thread.setDaemon() is deprecated
        _t.start()

    def call_release(self, arborted=False):
        """DEV: Must be called when the object becomes ready to read.

        Releases the lock the waiter thread (_wait_non_ressources) is
        blocked on. With *arborted* set, the waiter exits without calling
        its callback. Releasing when no waiter exists is silently ignored.
        """
        self.was_ended = arborted
        try:
            self.trigger.release()
        except (THREAD_EXCEPTION, AttributeError):
            pass

class SelectableSelector(object):
    """
    Select SelectableObject objects.

    inputs: objects to process
    remain: timeout in seconds. None waits until an input is ready;
            0 (or a negative value) only polls the inputs.
    """
    def _release_all(self):
        """Releases all locks to kill all threads"""
        for i in self.inputs:
            # Inputs of unknown type were ignored by process() and have no
            # call_release(); skip them instead of crashing.
            release = getattr(i, "call_release", None)
            if release is not None:
                release(True)
        self.available_lock.release()

    def _finish(self):
        """End the selection exactly once and wake up process().

        The check-and-set of _ended happens under self._lock, so a timeout
        and a ready input racing each other cannot both release
        available_lock. A second release would raise.
        """
        with self._lock:
            if self._ended:
                return
            self._ended = True
        self._release_all()

    def _timeout_thread(self, remain):
        """Timeout before releasing everything, if nothing was returned"""
        time.sleep(remain)
        self._finish()

    def _exit_door(self, _input):
        """This function is passed to each SelectableObject as a callback.
        The SelectableObjects have to call it once they are ready"""
        with self._lock:
            self.results.append(_input)
        self._finish()

    def __init__(self, inputs, remain):
        self.results = []
        self.inputs = list(inputs)
        self.remain = remain
        self.available_lock = threading.Lock()
        self.available_lock.acquire()
        self._ended = False
        # Guards _ended and results, which several threads touch
        self._lock = threading.Lock()

    def process(self):
        """Entry point of SelectableSelector: return the list of ready inputs.

        Off Windows this is a plain select.select() on the inputs. On
        Windows, inputs flagged __selectable_force_select__ still go through
        select.select(). Other SelectableObjects are either polled
        (remain <= 0) or watched through wait_return() until one is ready or
        the timeout expires (never, when remain is None).
        """
        if not WINDOWS:
            r, _, _ = select(self.inputs, [], [], self.remain)
            return r
        start = time.time()
        # remain=None means "wait until something is ready", as with
        # select.select. It used to be treated like 0 and returned at once.
        poll = self.remain is not None and self.remain <= 0
        select_inputs = []
        for i in self.inputs:
            if not isinstance(i, SelectableObject):
                warning("Unknown ignored object type: %s", type(i))
            elif i.__selectable_force_select__:
                # Then use select.select
                select_inputs.append(i)
            elif poll:
                # Only report what is ready now; do not leave waiter threads
                # behind that nobody will ever collect.
                if i.check_recv():
                    self.results.append(i)
            else:
                i.wait_return(self._exit_door)
        if select_inputs:
            # Use default select function.
            # NOTE(review): this blocks on the force-select inputs only, even
            # if another input becomes ready meanwhile.
            ready = select(select_inputs, [], [], self.remain)[0]
            with self._lock:
                self.results.extend(ready)
        if poll:
            return self.results
        with self._lock:
            already_ready = bool(self.results)
        if already_ready:
            # Something is ready already: do not wait for the timeout
            self._finish()
        elif self.remain is not None and not self._ended:
            # Count the time already spent in select() against the timeout
            left = max(0, self.remain - (time.time() - start))
            timer = threading.Thread(target=self._timeout_thread, args=(left,))
            timer.daemon = True
            timer.start()
        # NOTE(review): with remain=None and nothing to wait on, this blocks
        # forever, as select.select() does with empty lists.
        self.available_lock.acquire()
        # Return a snapshot, so late callbacks cannot alter the caller's list
        with self._lock:
            return list(self.results)

def select_objects(inputs, remain):
    """
    Portable counterpart of ``select.select(inputs, [], [], remain)[0]``.

    On Windows it also works on SelectableObject instances (and only on
    those), which select.select cannot handle there.

    inputs: objects to watch for readability
    remain: timeout, handed unchanged to SelectableSelector
    Returns the list of inputs that are ready to be read.
    """
    selector = SelectableSelector(inputs, remain)
    ready = selector.process()
    return ready

class ObjectPipe(SelectableObject):
    """FIFO of Python objects that can also be watched with select().

    Every queued object is mirrored by one byte in an OS pipe, so the read
    end returned by fileno() is readable exactly while objects are pending.
    """
    def __init__(self):
        pipe_fds = os.pipe()
        self.rd = pipe_fds[0]
        self.wr = pipe_fds[1]
        self.queue = deque()

    def fileno(self):
        return self.rd

    def check_recv(self):
        """True when at least one object is waiting in the queue."""
        return bool(self.queue)

    def send(self, obj):
        """Queue *obj*, signal it on the pipe and wake up waiters."""
        self.queue.append(obj)
        os.write(self.wr, b"X")
        self.call_release()

    def write(self, obj):
        self.send(obj)

    def recv(self, n=0):
        """Pop the oldest object (blocks on the pipe until one is queued).

        *n* is accepted for socket-like compatibility and ignored.
        """
        # Consume the marker byte first: this is what blocks when empty
        os.read(self.rd, 1)
        return self.queue.popleft()

    def read(self, n=0):
        return self.recv(n)

class Message:
    """Simple record: each keyword argument becomes an attribute."""
    def __init__(self, **args):
        vars(self).update(args)

    def __repr__(self):
        # Attributes starting with "_" are considered private and not shown
        shown = ["%s=%r" % (key, value)
                 for key, value in vars(self).items()
                 if not key.startswith("_")]
        return "<Message %s>" % " ".join(shown)

class _instance_state:
    def __init__(self, instance):
        self.__self__ = instance.__self__
        self.__func__ = instance.__func__
        self.__self__.__class__ = instance.__self__.__class__
    def __getattr__(self, attr):
        return getattr(self.__func__, attr)
    def __call__(self, *args, **kargs):
        return self.__func__(self.__self__, *args, **kargs)
    def breaks(self):
        return self.__self__.add_breakpoints(self.__func__)
    def intercepts(self):
        return self.__self__.add_interception_points(self.__func__)
    def unbreaks(self):
        return self.__self__.remove_breakpoints(self.__func__)
    def unintercepts(self):
        return self.__self__.remove_interception_points(self.__func__)
        

##############
## Automata ##
##############

class ATMT:
    """Automaton building blocks: the tags stored in ``atmt_type`` by the
    decorators below, and the NewStateRequested transition object."""
    STATE = "State"
    ACTION = "Action"
    CONDITION = "Condition"
    RECV = "Receive condition"
    TIMEOUT = "Timeout condition"
    IOEVENT = "I/O event"

    class NewStateRequested(Exception):
        """Request to move *automaton* to the state *state_func*.

        Returned by calling a state (see ATMT.state) and raised by
        conditions to take a transition. run() calls the state function
        with *args*/*kargs*. action_parameters() records the arguments given
        to the actions attached to the taken condition.
        """
        def __init__(self, state_func, automaton, *args, **kargs):
            self.func = state_func
            self.state = state_func.atmt_state
            self.initial = state_func.atmt_initial
            self.error = state_func.atmt_error
            self.final = state_func.atmt_final
            Exception.__init__(self, "Request state [%s]" % self.state)
            self.automaton = automaton
            # NB: this replaces Exception.args (set just above) with the
            # arguments destined to the state function.
            self.args = args
            self.kargs = kargs
            self.action_parameters() # init action parameters
        def action_parameters(self, *args, **kargs):
            """Set the arguments passed to the actions; returns self."""
            self.action_args = args
            self.action_kargs = kargs
            return self
        def run(self):
            """Call the requested state function on the automaton."""
            return self.func(self.automaton, *self.args, **self.kargs)
        def __repr__(self):
            return "NewStateRequested(%s)" % self.state

    @staticmethod
    def state(initial=0,final=0,error=0):
        """Decorator declaring a state (optionally initial, final or error).

        The decorated name becomes a wrapper that, when called on an
        automaton, returns an ATMT.NewStateRequested for the original
        function. The original stays reachable as ``atmt_origfunc``.
        """
        def deco(f, initial=initial, final=final, error=error):
            f.atmt_type = ATMT.STATE
            f.atmt_state = f.__name__
            f.atmt_initial = initial
            f.atmt_final = final
            f.atmt_error = error
            def state_wrapper(self, *args, **kargs):
                return ATMT.NewStateRequested(f, self, *args, **kargs)

            state_wrapper.__name__ = "%s_wrapper" % f.__name__
            # Keep the state's documentation reachable through the wrapper
            state_wrapper.__doc__ = f.__doc__
            state_wrapper.atmt_type = ATMT.STATE
            state_wrapper.atmt_state = f.__name__
            state_wrapper.atmt_initial = initial
            state_wrapper.atmt_final = final
            state_wrapper.atmt_error = error
            state_wrapper.atmt_origfunc = f
            return state_wrapper
        return deco
    @staticmethod
    def action(cond, prio=0):
        """Decorator attaching the function as an action of condition *cond*.

        May be stacked to attach one function to several conditions; *prio*
        orders the actions of a given condition.
        """
        def deco(f,cond=cond):
            # Test the attribute we are about to update: a function carrying
            # some other atmt_type but no atmt_cond used to raise
            # AttributeError here.
            if not hasattr(f, "atmt_cond"):
                f.atmt_cond = {}
            f.atmt_type = ATMT.ACTION
            f.atmt_cond[cond.atmt_condname] = prio
            return f
        return deco
    @staticmethod
    def condition(state, prio=0):
        """Decorator declaring an immediate condition of *state*."""
        def deco(f, state=state):
            f.atmt_type = ATMT.CONDITION
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            f.atmt_prio = prio
            return f
        return deco
    @staticmethod
    def receive_condition(state, prio=0):
        """Decorator declaring a condition of *state* evaluated on received packets."""
        def deco(f, state=state):
            f.atmt_type = ATMT.RECV
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            f.atmt_prio = prio
            return f
        return deco
    @staticmethod
    def ioevent(state, name, prio=0, as_supersocket=None):
        """Decorator declaring a condition of *state* on the I/O channel *name*.

        If *as_supersocket* is given, the metaclass also exposes the channel
        under that attribute name as a SuperSocket factory.
        """
        def deco(f, state=state):
            f.atmt_type = ATMT.IOEVENT
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            f.atmt_ioname = name
            f.atmt_prio = prio
            f.atmt_as_supersocket = as_supersocket
            return f
        return deco
    @staticmethod
    def timeout(state, timeout):
        """Decorator declaring a condition taken *timeout* seconds after entering *state*."""
        def deco(f, state=state, timeout=timeout):
            f.atmt_type = ATMT.TIMEOUT
            f.atmt_state = state.atmt_state
            f.atmt_timeout = timeout
            f.atmt_condname = f.__name__
            return f
        return deco

class _ATMT_Command:
    RUN = "RUN"
    NEXT = "NEXT"
    FREEZE = "FREEZE"
    STOP = "STOP"
    END = "END"
    EXCEPTION = "EXCEPTION"
    SINGLESTEP = "SINGLESTEP"
    BREAKPOINT = "BREAKPOINT"
    INTERCEPT = "INTERCEPT"
    ACCEPT = "ACCEPT"
    REPLACE = "REPLACE"
    REJECT = "REJECT"

class _ATMT_supersocket(SuperSocket):
    """SuperSocket front-end of an automaton I/O event.

    Creates a local datagram socket pair and hands one end (spb) to a new
    *automaton* instance as the external fd of *ioevent*. The other end
    (spa) is kept for send()/recv(), and the automaton's runbg() is then
    called.
    """
    def __init__(self, name, ioevent, automaton, proto, args, kargs):
        self.name = name
        self.ioevent = ioevent
        self.proto = proto
        pair = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
        self.spa, self.spb = pair
        kargs["external_fd"] = {ioevent: self.spb}
        self.atmt = automaton(*args, **kargs)
        self.atmt.runbg()

    def fileno(self):
        return self.spa.fileno()

    def send(self, s):
        data = s if isinstance(s, bytes) else bytes(s)
        return self.spa.send(data)

    def recv(self, n=MTU):
        """Read one datagram and dissect it with *proto* when one was given."""
        try:
            raw = self.spa.recv(n)
        except recv_error:
            if WINDOWS:
                return None
            raise
        return raw if self.proto is None else self.proto(raw)

    def close(self):
        # NOTE(review): no-op; the socket pair and the background automaton
        # are never released here.
        pass

class _ATMT_to_supersocket:
    def __init__(self, name, ioevent, automaton):
        self.name = name
        self.ioevent = ioevent
        self.automaton = automaton
    def __call__(self, proto, *args, **kargs):
        return _ATMT_supersocket(self.name, self.ioevent, self.automaton, proto, args, kargs)

class Automaton_metaclass(type):
    """Metaclass collecting the ATMT-decorated members of an automaton class.

    At class creation it builds, per state name, the lists of conditions,
    receive conditions, I/O events and timeouts (sorted by priority or by
    timeout value). It also builds the actions attached to each condition,
    and installs a SuperSocket factory for each I/O event declared with
    ``as_supersocket``.
    """
    def __new__(cls, name, bases, dct):
        cls = super(Automaton_metaclass, cls).__new__(cls, name, bases, dct)
        cls.states={}
        cls.state = None
        cls.recv_conditions={}
        cls.conditions={}
        cls.ioevents={}
        cls.timeout={}
        cls.actions={}
        cls.initial_states=[]
        cls.ionames = []
        cls.iosupersockets = []

        # Gather the members of cls and of its bases (breadth-first); the
        # first definition found wins, so overridden members shadow parents.
        members = {}
        classes = [cls]
        while classes:
            c = classes.pop(0) # order is important to avoid breaking method overloading
            classes += list(c.__bases__)
            for k,v in six.iteritems(c.__dict__):
                if k not in members:
                    members[k] = v

        decorated = [v for v in six.itervalues(members)
                     if isinstance(v, types.FunctionType) and hasattr(v, "atmt_type")]

        # First pass: register the states and their (empty) per-state lists
        for m in decorated:
            if m.atmt_type == ATMT.STATE:
                s = m.atmt_state
                cls.states[s] = m
                cls.recv_conditions[s]=[]
                cls.ioevents[s]=[]
                cls.conditions[s]=[]
                cls.timeout[s]=[]
                if m.atmt_initial:
                    cls.initial_states.append(m)
            elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT]:
                cls.actions[m.atmt_condname] = []

        # Second pass: attach conditions, I/O events, timeouts and actions
        for m in decorated:
            if m.atmt_type == ATMT.CONDITION:
                cls.conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.RECV:
                cls.recv_conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.IOEVENT:
                cls.ioevents[m.atmt_state].append(m)
                cls.ionames.append(m.atmt_ioname)
                if m.atmt_as_supersocket is not None:
                    cls.iosupersockets.append(m)
            elif m.atmt_type == ATMT.TIMEOUT:
                cls.timeout[m.atmt_state].append((m.atmt_timeout, m))
            elif m.atmt_type == ATMT.ACTION:
                for c in m.atmt_cond:
                    cls.actions[c].append(m)

        # Sort with key functions: same stable order as the former
        # cmp_to_key/cmp comparators, without needing a cmp() in scope.
        for v in six.itervalues(cls.timeout):
            v.sort(key=lambda timeout_and_func: timeout_and_func[0])
            # Sentinel meaning "no further timeout"
            v.append((None, None))
        for v in itertools.chain(six.itervalues(cls.conditions),
                                 six.itervalues(cls.recv_conditions),
                                 six.itervalues(cls.ioevents)):
            v.sort(key=lambda cond: cond.atmt_prio)
        for condname,actlst in six.iteritems(cls.actions):
            actlst.sort(key=lambda act, condname=condname: act.atmt_cond[condname])

        for ioev in cls.iosupersockets:
            setattr(cls, ioev.atmt_as_supersocket, _ATMT_to_supersocket(ioev.atmt_as_supersocket, ioev.atmt_ioname, cls))

        return cls

    def graph(self, **kargs):
        """Render the automaton as a Graphviz digraph through do_graph().

        States are nodes (initial: blue, final: green, error: red). Edges
        are found by looking for state names among the names and constants
        used by each state function, condition and timeout. *kargs* are
        passed to do_graph().
        """
        # graph() is a metaclass method, so self is the automaton class:
        # use its own name (self.__class__ would be Automaton_metaclass).
        s = 'digraph "%s" {\n'  % self.__name__

        se = "" # Keep initial nodes at the begining for better rendering
        for st in six.itervalues(self.states):
            if st.atmt_initial:
                se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state)+se
            elif st.atmt_final:
                se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state
            elif st.atmt_error:
                se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state
        s += se

        for st in six.itervalues(self.states):
            for n in st.atmt_origfunc.__code__.co_names+st.atmt_origfunc.__code__.co_consts:
                if n in self.states:
                    s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state,n)

        for c,k,v in ([("purple",k,v) for k,v in self.conditions.items()]+
                      [("red",k,v) for k,v in self.recv_conditions.items()]+
                      [("orange",k,v) for k,v in self.ioevents.items()]):
            for f in v:
                for n in f.__code__.co_names+f.__code__.co_consts:
                    if n in self.states:
                        l = f.atmt_condname
                        for x in self.actions[f.atmt_condname]:
                            l += "\\l>[%s]" % x.__name__
                        s += '\t"%s" -> "%s" [label="%s", color=%s];\n' % (k,n,l,c)
        for k,v in six.iteritems(self.timeout):
            for t,f in v:
                if f is None:
                    continue
                for n in f.__code__.co_names+f.__code__.co_consts:
                    if n in self.states:
                        l = "%s/%.1fs" % (f.atmt_condname,t)
                        for x in self.actions[f.atmt_condname]:
                            l += "\\l>[%s]" % x.__name__
                        s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k,n,l)
        s += "}\n"
        return do_graph(s, **kargs)

class Automaton(six.with_metaclass(Automaton_metaclass)):
    def parse_args(self, debug=0, store=1, **kargs):
        """Record the run-time options of the automaton.

        debug: verbosity threshold compared against by debug()
        store: when true, sent and received packets are kept in self.packets
        kargs: remaining options, later passed as keyword arguments to the
               send and receive socket classes
        """
        self.store_packets = store
        self.socket_kargs = kargs
        self.debug_level = debug

    def master_filter(self, pkt):
        """Packet filter hook for subclasses to override.

        Presumably consulted for packets received by the automaton (the
        caller is outside this view). This default accepts every packet.
        """
        accepted = True
        return accepted

    def my_send(self, pkt):
        """Low-level emission hook used by send(): write *pkt* to the send socket."""
        sock = self.send_sock
        sock.send(pkt)


    ## Utility classes and exceptions
    class _IO_fdwrapper(SelectableObject):
        """Wrap external file descriptors as an automaton I/O channel.

        Off Windows, *rd* and *wr* may be ints, objects with fileno(), or
        None, and are used with os.read()/os.write(). On Windows only an
        ObjectPipe is accepted as *rd*; it is used both to read and to write.
        """
        def __init__(self,rd,wr):
            if WINDOWS:
                # rd will be used for reading and sending
                if isinstance(rd, ObjectPipe):
                    self.rd = rd
                else:
                    raise OSError("On windows, only instances of ObjectPipe are externally available")
            else:
                if rd is not None and not isinstance(rd, int):
                    rd = rd.fileno()
                if wr is not None and not isinstance(wr, int):
                    wr = wr.fileno()
                self.rd = rd
                self.wr = wr
        def fileno(self):
            return self.rd
        def check_recv(self):
            # Only meaningful on Windows, where self.rd is an ObjectPipe
            return self.rd.check_recv()
        def read(self, n=65535):
            # _IO_mixer.recv() forwards n=None by default. os.read() rejects
            # None, so fall back to the default read size.
            if n is None:
                n = 65535
            if WINDOWS:
                return self.rd.recv(n)
            return os.read(self.rd, n)
        def write(self, msg):
            if WINDOWS:
                self.rd.send(msg)
                return self.call_release()
            return os.write(self.wr,msg)
        def recv(self, n=65535):
            return self.read(n)
        def send(self, msg):
            return self.write(msg)

    class _IO_mixer(SelectableObject):
        """Bidirectional I/O channel: reads from *rd* and writes to *wr*.

        The automaton exposes one mixer per I/O event on self.io, and the
        mirror one (ends swapped) on self.oi.
        """
        def __init__(self,rd,wr):
            self.rd, self.wr = rd, wr
        def fileno(self):
            source = self.rd
            return source if isinstance(source, int) else source.fileno()
        def check_recv(self):
            return self.rd.check_recv()
        def recv(self, n=None):
            return self.rd.recv(n)
        def read(self, n=None):
            return self.recv(n)
        def send(self, msg):
            self.wr.send(msg)
            # Wake up whoever waits on this mixer (see SelectableObject)
            return self.call_release()
        def write(self, msg):
            return self.send(msg)


    class AutomatonException(Exception):
        """Base class of the automaton exceptions.

        Carries the state it relates to and an optional result.
        """
        def __init__(self, msg, state=None, result=None):
            Exception.__init__(self, msg)
            self.state = state
            self.result = result

    class AutomatonError(AutomatonException):
        """Generic automaton failure (e.g. an unknown interception verdict)."""
    class ErrorState(AutomatonException):
        """Raised when the automaton reaches a state declared with error=1."""
    class Stuck(AutomatonException):
        """Raised when a state has no receive condition, I/O event or timeout left."""
    class AutomatonStopped(AutomatonException):
        """Common base of Breakpoint, Singlestep and InterceptionPoint."""

    class Breakpoint(AutomatonStopped):
        """Breakpoint triggered when entering a state."""
    class Singlestep(AutomatonStopped):
        """Stop after a single step."""
    class InterceptionPoint(AutomatonStopped):
        """Packet interception; also carries the intercepted *packet*."""
        def __init__(self, msg, state=None, result=None, packet=None):
            Automaton.AutomatonStopped.__init__(self, msg, state=state, result=result)
            self.packet = packet

    class CommandMessage(AutomatonException):
        """Yielded by the state iterator when a command arrives on cmdin."""


    ## Services
    def debug(self, lvl, msg):
        """Log *msg* through log_interactive.debug when the automaton's debug
        level is at least *lvl*."""
        if self.debug_level < lvl:
            return
        log_interactive.debug(msg)

    def send(self, pkt):
        """Send *pkt* from within a state, honouring interception points.

        When the current state is an interception point, the packet is
        reported as an INTERCEPT message on cmdout. The verdict read back
        from cmdin then decides: REJECT drops the packet, REPLACE substitutes
        the verdict's pkt, ACCEPT keeps it, and anything else raises
        AutomatonError. The resulting packet goes through my_send(), and a
        copy is stored when packet storing is enabled.
        """
        if self.state.state in self.interception_points:
            self.debug(3,"INTERCEPT: packet intercepted: %s" % pkt.summary())
            self.intercepted_packet = pkt
            self.cmdout.send(Message(type = _ATMT_Command.INTERCEPT, state=self.state, pkt=pkt))
            verdict = self.cmdin.recv()
            self.intercepted_packet = None
            kind = verdict.type
            if kind == _ATMT_Command.REJECT:
                self.debug(3,"INTERCEPT: packet rejected")
                return
            if kind == _ATMT_Command.REPLACE:
                pkt = verdict.pkt
                self.debug(3,"INTERCEPT: packet replaced by: %s" % pkt.summary())
            elif kind == _ATMT_Command.ACCEPT:
                self.debug(3,"INTERCEPT: packet accepted")
            else:
                raise self.AutomatonError("INTERCEPT: unkown verdict: %r" % kind)
        self.my_send(pkt)
        self.debug(3,"SENT : %s" % pkt.summary())
        if self.store_packets:
            self.packets.append(pkt.copy())


    ## Internals
    def __init__(self, *args, **kargs):
        """Build the automaton's plumbing and call self.start().

        Keyword options consumed here:
          external_fd -- mapping of I/O event name to an fd/object, or to an
                         (input, output) tuple (tuples rejected on Windows)
          ll          -- send socket class (default conf.L3socket)
          recvsock    -- receive socket class (default conf.L2listen)
        The remaining args/kargs are kept in init_args/init_kargs.
        """
        external_fd = kargs.pop("external_fd",{})
        self.send_sock_class = kargs.pop("ll", conf.L3socket)
        self.recv_sock_class = kargs.pop("recvsock", conf.L2listen)
        # Control thread bookkeeping
        self.started = threading.Lock()
        self.threadid = None
        # Debugging helpers: breakpoints and packet interception
        self.breakpointed = None
        self.breakpoints = set()
        self.interception_points = set()
        self.intercepted_packet = None
        self.debug_level=0
        self.init_args=args
        self.init_kargs=kargs
        # Bare classes used as attribute namespaces for the I/O mixers
        self.io = type.__new__(type, "IOnamespace",(),{})
        self.oi = type.__new__(type, "IOnamespace",(),{})
        # Command channels between the control thread and its controller
        self.cmdin = ObjectPipe()
        self.cmdout = ObjectPipe()
        self.ioin = {}
        self.ioout = {}
        for ioname in self.ionames:
            ends = external_fd.get(ioname)
            if isinstance(ends, tuple):
                if WINDOWS:
                    raise OSError("Tuples are not allowed as external_fd on windows")
            else:
                ends = (ends, ends)
            ioin, ioout = ends
            # The input end must be resolved first: on Windows a missing
            # output end reuses it.
            if ioin is None:
                ioin = ObjectPipe()
            elif not isinstance(ioin, SelectableObject):
                ioin = self._IO_fdwrapper(ioin, None)
            if ioout is None:
                ioout = ioin if WINDOWS else ObjectPipe()
            elif not isinstance(ioout, SelectableObject):
                ioout = self._IO_fdwrapper(None, ioout)

            self.ioin[ioname] = ioin
            self.ioout[ioname] = ioout
            ioin.ioname = ioname
            ioout.ioname = ioname
            setattr(self.io, ioname, self._IO_mixer(ioout, ioin))
            setattr(self.oi, ioname, self._IO_mixer(ioin, ioout))

        # Replace each state method by a per-instance _instance_state proxy
        for state_name in self.states:
            setattr(self, state_name, _instance_state(getattr(self, state_name)))

        self.start()

    def __iter__(self):
        """The automaton is its own iterator (its next() lies outside this view)."""
        it = self
        return it

    def __del__(self):
        """Call stop() when the automaton object is garbage-collected."""
        stop = self.stop
        stop()

    def _run_condition(self, cond, *args, **kargs):
        """Evaluate one transition condition of the current state.

        A condition is taken by raising ATMT.NewStateRequested. In that case
        the triggering packet of a receive condition is stored (if storing
        is enabled), and the condition's actions run with the parameters
        recorded by action_parameters(). The request is then re-raised.
        Other exceptions are logged and re-raised. A condition that returns
        normally was not taken.
        """
        kind, name = cond.atmt_type, cond.atmt_condname
        try:
            self.debug(5, "Trying %s [%s]" % (kind, name))
            cond(self,*args, **kargs)
        except ATMT.NewStateRequested as request:
            self.debug(2, "%s [%s] taken to state [%s]" % (kind, name, request.state))
            if kind == ATMT.RECV and self.store_packets:
                self.packets.append(args[0])
            for action in self.actions[name]:
                self.debug(2, "   + Running action [%s]" % action.__name__)
                action(self, *request.action_args, **request.action_kargs)
            raise
        except Exception as e:
            self.debug(2, "%s [%s] raised exception [%s]" % (kind, name, e))
            raise
        else:
            self.debug(2, "%s [%s] not taken" % (kind, name))

    def _do_start(self, *args, **kargs):
        """Start the control thread (_do_control) and wait until it is ready.

        The thread is a daemon. *args*/*kargs* are forwarded to
        _do_control(). Returns once the thread has set the *ready* event.
        """
        ready = threading.Event()
        _t = threading.Thread(target=self._do_control, args=(ready,) + (args), kwargs=kargs)
        # Thread.setDaemon() is deprecated (Python 3.10+); the attribute is
        # available on every supported Python version.
        _t.daemon = True
        _t.start()
        ready.wait()

    def _do_control(self, ready, *args, **kargs):
        """Body of the control thread started by _do_start().

        Sets the automaton up (parse_args, initial state, sockets, packet
        list), signals *ready*, then serves the commands read from cmdin:
        - RUN leaves single-step mode and NEXT enters it.
        - FREEZE is ignored and STOP ends the loop.
        - Any other command then advances the automaton until a command
          message, a breakpoint (reported as BREAKPOINT on cmdout) or, in
          single-step mode, the next state (reported as SINGLESTEP).
        The end of the automaton is reported as END with its result. Any
        exception, including one raised during setup, is reported as
        EXCEPTION with its exc_info.
        """
        with self.started:
            self.threadid = threading.current_thread().ident
            try:
                # Update default parameters
                a = args+self.init_args[len(args):]
                k = self.init_kargs.copy()
                k.update(kargs)
                self.parse_args(*a,**k)

                # Start the automaton
                self.state=self.initial_states[0](self)
                self.send_sock = self.send_sock_class(**self.socket_kargs)
                self.listen_sock = self.recv_sock_class(**self.socket_kargs)
                self.packets = PacketList(name="session[%s]"%self.__class__.__name__)

                singlestep = True
                iterator = self._do_iter()
                self.debug(3, "Starting control thread [tid=%i]" % self.threadid)
                # Sync threads
                ready.set()
                while True:
                    c = self.cmdin.recv()
                    self.debug(5, "Received command %s" % c.type)
                    if c.type == _ATMT_Command.RUN:
                        singlestep = False
                    elif c.type == _ATMT_Command.NEXT:
                        singlestep = True
                    elif c.type == _ATMT_Command.FREEZE:
                        continue
                    elif c.type == _ATMT_Command.STOP:
                        break
                    while True:
                        state = next(iterator)
                        if isinstance(state, self.CommandMessage):
                            break
                        elif isinstance(state, self.Breakpoint):
                            c = Message(type=_ATMT_Command.BREAKPOINT,state=state)
                            self.cmdout.send(c)
                            break
                        if singlestep:
                            c = Message(type=_ATMT_Command.SINGLESTEP,state=state)
                            self.cmdout.send(c)
                            break
            except StopIteration as e:
                # NOTE(review): since Python 3.7 (PEP 479), the
                # `raise StopIteration(...)` inside the _do_iter generator
                # surfaces as RuntimeError and lands in the handler below.
                # A bare StopIteration has no args; it used to raise
                # IndexError here and the END message was never sent.
                result = e.args[0] if e.args else None
                c = Message(type=_ATMT_Command.END, result=result)
                self.cmdout.send(c)
            except Exception as e:
                exc_info = sys.exc_info()
                self.debug(3, "Transfering exception from tid=%i:\n%s"% (self.threadid, traceback.format_exception(*exc_info)))
                m = Message(type=_ATMT_Command.EXCEPTION, exception=e, exc_info=exc_info)
                self.cmdout.send(m)
            finally:
                # Always unblock _do_start(), even when setup failed (a
                # failure used to leave it waiting forever). The error is
                # already queued on cmdout.
                ready.set()
            self.debug(3, "Stopping control thread (tid=%i)"%self.threadid)
            self.threadid = None
    
    def _do_iter(self):
        """Return an iterator that drives the automaton's state machine.

        Every ``next()`` call on the returned iterator runs the automaton
        until there is something to report, and returns it:

        * a ``self.Breakpoint`` instance when a breakpointed state is
          entered (the following ``next()`` then runs that state);
        * a ``self.CommandMessage`` instance when ``self.cmdin`` is readable;
        * the ``ATMT.NewStateRequested`` object of every state transition.

        When a final state is reached, ``next()`` raises
        ``StopIteration(state_output)`` -- the control thread catches it and
        reads the result from ``e.args[0]``.  Reaching an error state raises
        ``self.ErrorState``; a state with no receive condition, no I/O event
        and a single ``self.timeout`` entry (presumably only the
        ``(None, None)`` end marker) raises ``self.Stuck``.

        Implementation note: the state machine is a nested generator.  It
        must not end itself with ``raise StopIteration`` (PEP 479 turns that
        into ``RuntimeError`` on Python 3.7+), and ``return value`` inside a
        generator is a ``SyntaxError`` on Python 2.  The final output
        therefore leaves the generator through a private exception, which a
        small iterator wrapper converts back into ``StopIteration``.
        """

        class _FinalStateReached(Exception):
            """Carries the output of a final state out of ``_states()``."""

        def _states():
            while True:
                try:
                    self.debug(1, "## state=[%s]" % self.state.state)

                    # Entering a new state. Honour a breakpoint once per
                    # visit, then call the new state function.
                    if (self.state.state in self.breakpoints and
                            self.state.state != self.breakpointed):
                        self.breakpointed = self.state.state
                        yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state,
                                              state=self.state.state)
                    self.breakpointed = None
                    state_output = self.state.run()
                    if self.state.error:
                        raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output),
                                              result=state_output, state=self.state.state)
                    if self.state.final:
                        # Not StopIteration: see the note in the docstring.
                        raise _FinalStateReached(state_output)

                    # Turn the state output into extra positional arguments
                    # for the condition functions.
                    if state_output is None:
                        state_output = ()
                    elif not isinstance(state_output, list):
                        state_output = state_output,

                    # Then check immediate conditions
                    for cond in self.conditions[self.state.state]:
                        self._run_condition(cond, *state_output)

                    # If still there and no conditions left, we are stuck!
                    if (len(self.recv_conditions[self.state.state]) == 0 and
                            len(self.ioevents[self.state.state]) == 0 and
                            len(self.timeout[self.state.state]) == 1):
                        raise self.Stuck("stuck in [%s]" % self.state.state,
                                         state=self.state.state, result=state_output)

                    # Finally listen and pay attention to timeouts
                    expirations = iter(self.timeout[self.state.state])
                    next_timeout, timeout_func = next(expirations)
                    t0 = time.time()

                    fds = [self.cmdin]
                    if len(self.recv_conditions[self.state.state]) > 0:
                        fds.append(self.listen_sock)
                    for ioev in self.ioevents[self.state.state]:
                        fds.append(self.ioin[ioev.atmt_ioname])
                    while True:
                        t = time.time() - t0
                        # Fire every timeout that has already expired, not
                        # just the first one.  Otherwise remain could become
                        # negative, and select.select() rejects a negative
                        # timeout with ValueError.
                        while next_timeout is not None and next_timeout <= t:
                            self._run_condition(timeout_func, *state_output)
                            next_timeout, timeout_func = next(expirations)
                        if next_timeout is None:
                            remain = None
                        else:
                            remain = next_timeout - t

                        self.debug(5, "Select on %r" % fds)
                        r = select_objects(fds, remain)
                        self.debug(5, "Selected %r" % r)
                        for fd in r:
                            self.debug(5, "Looking at %r" % fd)
                            if fd == self.cmdin:
                                yield self.CommandMessage("Received command message")
                            elif fd == self.listen_sock:
                                try:
                                    pkt = self.listen_sock.recv(MTU)
                                except recv_error:
                                    pass
                                else:
                                    if pkt is not None:
                                        if self.master_filter(pkt):
                                            self.debug(3, "RECVD: %s" % pkt.summary())
                                            for rcvcond in self.recv_conditions[self.state.state]:
                                                self._run_condition(rcvcond, pkt, *state_output)
                                        else:
                                            self.debug(4, "FILTR: %s" % pkt.summary())
                            else:
                                self.debug(3, "IOEVENT on %s" % fd.ioname)
                                for ioevt in self.ioevents[self.state.state]:
                                    if ioevt.atmt_ioname == fd.ioname:
                                        self._run_condition(ioevt, fd, *state_output)

                except ATMT.NewStateRequested as state_req:
                    self.debug(2, "switching from [%s] to [%s]" % (self.state.state, state_req.state))
                    self.state = state_req
                    yield state_req

        class _StateIterator(object):
            """Iterate over ``_states()`` and turn the final-state signal
            into ``StopIteration(state_output)``."""

            def __init__(self, states):
                self._states = states

            def __iter__(self):
                return self

            def __next__(self):
                try:
                    return next(self._states)
                except _FinalStateReached as final:
                    # Raising StopIteration from a plain __next__ (not a
                    # generator) is the normal iterator protocol.
                    raise StopIteration(final.args[0])

            next = __next__  # Python 2 iterator protocol

        return _StateIterator(_states())

    ## Public API
    def add_interception_points(self, *ipts):
        """Add states to ``self.interception_points``.

        Each argument is either a state name, or an object exposing an
        ``atmt_state`` attribute (presumably a decorated state method), in
        which case that attribute is stored instead.
        """
        for point in ipts:
            state = point.atmt_state if hasattr(point, "atmt_state") else point
            self.interception_points.add(state)
        
    def remove_interception_points(self, *ipts):
        """Remove states from ``self.interception_points``.

        Arguments are resolved like in ``add_interception_points`` (an
        ``atmt_state`` attribute wins over the object itself).  States that
        are not registered are silently ignored.
        """
        for point in ipts:
            state = point.atmt_state if hasattr(point, "atmt_state") else point
            self.interception_points.discard(state)

    def add_breakpoints(self, *bps):
        """Add states to ``self.breakpoints``.

        Each argument is a state name or an object carrying an
        ``atmt_state`` attribute, whose value is used instead.
        """
        for breakpoint_ in bps:
            state = breakpoint_.atmt_state if hasattr(breakpoint_, "atmt_state") else breakpoint_
            self.breakpoints.add(state)

    def remove_breakpoints(self, *bps):
        """Remove states from ``self.breakpoints``.

        Arguments are resolved like in ``add_breakpoints``.  Unknown states
        are ignored rather than raising.
        """
        for breakpoint_ in bps:
            state = breakpoint_.atmt_state if hasattr(breakpoint_, "atmt_state") else breakpoint_
            self.breakpoints.discard(state)

    def start(self, *args, **kargs):
        """Call ``self._do_start(*args, **kargs)`` unless ``self.started`` is
        locked (presumably meaning the automaton is already running).

        Always returns None.
        """
        if self.started.locked():
            return
        self._do_start(*args, **kargs)
        
    def run(self, resume=None, wait=True):
        """Send a command to the control thread and optionally wait for its reply.

        :param resume: the command ``Message`` to send; defaults to a RUN
            command.
        :param wait: when false, return None right after sending.
        :returns: the automaton's result when the reply is END.  It also
            returns None if interrupted by KeyboardInterrupt (a FREEZE
            command is sent first) or when the reply type is not handled
            below.
        :raises InterceptionPoint, Singlestep, Breakpoint: for the
            corresponding replies.
        :raises: the exception forwarded by the control thread, with its
            original traceback, for an EXCEPTION reply.
        """
        if resume is None:
            resume = Message(type = _ATMT_Command.RUN)
        self.cmdin.send(resume)
        if not wait:
            return None

        try:
            reply = self.cmdout.recv()
        except KeyboardInterrupt:
            self.cmdin.send(Message(type = _ATMT_Command.FREEZE))
            return None

        kind = reply.type
        if kind == _ATMT_Command.END:
            return reply.result
        if kind == _ATMT_Command.INTERCEPT:
            raise self.InterceptionPoint("packet intercepted", state=reply.state.state, packet=reply.pkt)
        if kind == _ATMT_Command.SINGLESTEP:
            raise self.Singlestep("singlestep state=[%s]"%reply.state.state, state=reply.state.state)
        if kind == _ATMT_Command.BREAKPOINT:
            raise self.Breakpoint("breakpoint triggered on state [%s]"%reply.state.state, state=reply.state.state)
        if kind == _ATMT_Command.EXCEPTION:
            # exc_info is the (type, value, traceback) triple captured by the
            # control thread with sys.exc_info().
            six.reraise(*reply.exc_info)

    def runbg(self, resume=None, wait=False):
        """Like ``run()``, but does not wait for a reply by default.

        The value returned by ``run()`` is discarded; this always returns None.
        """
        self.run(resume=resume, wait=wait)

    def next(self):
        """Send a NEXT command through ``run()`` and return what ``run()``
        returns (it waits for the reply)."""
        step_command = Message(type=_ATMT_Command.NEXT)
        return self.run(resume=step_command)
    # Python 3 spelling of the iterator-protocol method.
    __next__ = next

    def stop(self):
        """Send a STOP command, wait on ``self.started``, then drain both pipes.

        ``self.started`` is presumably held by the control thread while it
        runs (``start()`` only starts when it is unlocked), so acquiring it
        waits for that thread to finish.
        """
        self.cmdin.send(Message(type=_ATMT_Command.STOP))
        with self.started:
            # Flush command pipes: read until neither end has data pending.
            pending = select_objects([self.cmdin, self.cmdout], 0)
            while pending:
                for pipe in pending:
                    pipe.recv()
                pending = select_objects([self.cmdin, self.cmdout], 0)
                
    def restart(self, *args, **kargs):
        """Stop the automaton, then start it again, passing all arguments to
        ``start()``.  Returns None."""
        self.stop()
        # Arguments are forwarded untouched to start().
        self.start(*args, **kargs)

    def accept_packet(self, pkt=None, wait=False):
        """Resume the automaton with an ACCEPT command, or a REPLACE command
        carrying ``pkt`` when a replacement packet is given.

        :param pkt: replacement packet; None to accept as-is.
        :param wait: forwarded to ``run()``.
        :returns: whatever ``run()`` returns.
        """
        resume_msg = Message()
        if pkt is not None:
            resume_msg.type = _ATMT_Command.REPLACE
            resume_msg.pkt = pkt
        else:
            resume_msg.type = _ATMT_Command.ACCEPT
        return self.run(resume=resume_msg, wait=wait)

    def reject_packet(self, wait=False):
        """Resume the automaton with a REJECT command.

        :param wait: forwarded to ``run()``.
        :returns: whatever ``run()`` returns.
        """
        return self.run(resume=Message(type=_ATMT_Command.REJECT), wait=wait)

    

