blob: b2b6afa7a632f22fe14dc937b6d226209f85ff53 [file] [log] [blame]
from contextlib import contextmanager
from typing import cast
from . import api
from . import TensorPipeAgent
@contextmanager
def _group_membership_management(store, name, is_join):
token_key = "RpcGroupManagementToken"
my_token = f"Token{name}-{int(is_join)}"
while True:
# Retrieve token from store to signal start of rank join/leave critical section
returned = store.compare_set(token_key, "", my_token).decode()
if returned == my_token:
# Yield to the function this context manager wraps
yield
# Finished, now exit and release token
# Update from store to signal end of rank join/leave critical section
store.set(token_key, "")
# Other will wait for this token to be set before they execute
store.set(my_token, "Done")
break
else:
# token_name = returned.split("-")[0]
# Store will wait for the token to be released
store.wait([returned])
def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
    """Forward a group-membership update to the current RPC agent.

    All arguments are passed through unchanged, positionally, to the agent's
    ``_update_group_membership`` method, and its result is returned as-is.

    Args:
        worker_info: Passed through to the agent.
        my_devices: Passed through to the agent.
        reverse_device_map: Passed through to the agent.
        is_join: Passed through to the agent.

    Returns:
        Whatever the agent's ``_update_group_membership`` returns.
    """
    current_agent = api._get_current_rpc_agent()
    # ``cast`` only informs static type checkers; it performs no runtime check.
    # The current agent is assumed to be a TensorPipeAgent.
    tensorpipe_agent = cast(TensorPipeAgent, current_agent)
    return tensorpipe_agent._update_group_membership(
        worker_info, my_devices, reverse_device_map, is_join
    )