Skip to content

Commit

Permalink
Simulartor hang oom (NVIDIA#1159)
Browse files Browse the repository at this point in the history
* Fixed the simulator hang when client run OOM.

* Removed no used comment.

* Fixed a typo.

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
  • Loading branch information
yhwen and YuanTingHsieh authored Dec 7, 2022
1 parent 7dac994 commit 62ac8c6
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 17 deletions.
30 changes: 26 additions & 4 deletions nvflare/private/fed/client/fed_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,17 @@ def send_heartbeat(self, project_name):
except FLCommunicationError:
self.communicator.heartbeat_done = True

def quit_remote(self, project_name, fl_ctx: FLContext):
"""Sending the last message to the server before leaving.
Args:
fl_ctx: FLContext
Returns: N/A
"""
return self.communicator.quit_remote(self.servers, project_name, self.token, self.ssid, fl_ctx)

def heartbeat(self):
"""Sends a heartbeat from the client to the server."""
pool = None
Expand Down Expand Up @@ -341,17 +352,22 @@ def start_heartbeat(self):
heartbeat_thread = threading.Thread(target=self.run_heartbeat)
heartbeat_thread.start()

def quit_remote(self, task_name, fl_ctx: FLContext):
"""Sending the last message to the server before leaving.
def logout_client(self, fl_ctx: FLContext):
"""Logout the client from the server.
Args:
task_name: task name
fl_ctx: FLContext
Returns: N/A
"""
return self.communicator.quit_remote(self.servers, task_name, self.token, self.ssid, fl_ctx)
pool = None
try:
pool = ThreadPool(len(self.servers))
return pool.map(partial(self.quit_remote, fl_ctx=fl_ctx), tuple(self.servers))
finally:
if pool:
pool.terminate()

def set_client_engine(self, engine):
self.engine = engine
Expand All @@ -362,4 +378,10 @@ def close(self):
if self.overseer_agent:
self.overseer_agent.end()

if self.engine:
fl_ctx = self.engine.new_context()
else:
fl_ctx = FLContext()
self.logout_client(fl_ctx)

return 0
20 changes: 10 additions & 10 deletions nvflare/private/fed/server/fed_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,17 +207,21 @@ def remove_dead_clients(self):
if client.last_connect_time < time.time() - self.heart_beat_timeout:
delete.append(token)
for token in delete:
client = self.client_manager.remove_client(token)
self.remove_client_data(token)
if self.admin_server:
self.admin_server.client_dead(token)
self.notify_dead_client(client)
client = self.logout_client(token)
self.logger.info(
"Remove the dead Client. Name: {}\t Token: {}. Total clients: {}".format(
client.name, token, len(self.client_manager.get_clients())
)
)

def logout_client(self, token):
client = self.client_manager.remove_client(token)
self.remove_client_data(token)
if self.admin_server:
self.admin_server.client_dead(token)
self.notify_dead_client(client)
return client

def notify_dead_client(self, client):
"""Called to do further processing of the dead client
Expand Down Expand Up @@ -390,11 +394,7 @@ def Quit(self, request, context):
client = self.client_manager.validate_client(request, context)
if client:
token = client.get_token()

_ = self.client_manager.remove_client(token)
self.tokens.pop(token, None)
if self.admin_server:
self.admin_server.client_dead(token)
self.logout_client(token)

return fed_msg.FederatedSummary(comment="Removed client")

Expand Down
3 changes: 0 additions & 3 deletions nvflare/private/fed/simulator/simulator_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,6 @@ def _submit_update(self, data, shared_fl_context):
server_runner = fl_ctx.get_prop(FLContextKey.RUNNER)
server_runner.process_submission(client, contribution_task_name, task_id, shareable, fl_ctx)

def remove_dead_clients(self):
pass

def _aux_communicate(self, fl_ctx, shareable, shared_fl_context, topic):
try:
with self.engine.lock:
Expand Down

0 comments on commit 62ac8c6

Please sign in to comment.