Skip to content

Commit

Permalink
Ft exit (#83)
Browse files Browse the repository at this point in the history
* first attempt at clean exit. Does not cleanly exit in console.

* rewrite broad exception.

* change poetry version

* same env name in tests, codestyle changes
  • Loading branch information
bheijden authored Feb 24, 2022
1 parent e801da5 commit 0da70fc
Show file tree
Hide file tree
Showing 14 changed files with 162 additions and 71 deletions.
3 changes: 3 additions & 0 deletions eagerx/core/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ def initialize(self, *args, **kwargs):
"""A method to initialize this node."""
pass

def shutdown(self):
    """A method to clean up this node on shutdown. Does nothing by default."""
    return None


class Node(BaseNode):
def __init__(self, **kwargs):
Expand Down
46 changes: 36 additions & 10 deletions eagerx/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from eagerx.core.constants import process

# OTHER IMPORTS
import atexit
import abc
import cv2
import numpy as np
Expand Down Expand Up @@ -50,21 +51,25 @@ def __init__(self, name: str, rate: float, graph: Graph, bridge: BridgeSpec) ->
self.ns = "/" + name
self.rate = rate
self.initialized = False
self.has_shutdown = False

# Take deepcopy of bridge
bridge = BridgeSpec(bridge.params)
self._bridge_name = bridge.params["config"]["entity_id"]

# Register graph
self.graph = graph
nodes, objects, actions, observations, self.render_node = graph.register()

# Initialize supervisor node
self.mb, self.supervisor_node, _ = self._init_supervisor(bridge, nodes, objects)
self.mb, self.supervisor_node, self.supervisor = self._init_supervisor(bridge, nodes, objects)
self._is_initialized = self.supervisor_node.is_initialized

# Initialize bridge
self._init_bridge(bridge, nodes)

# Create environment node
self.env_node, _ = self._init_environment(actions, observations, self.supervisor_node, self.mb)
self.env_node, self.env = self._init_environment(actions, observations, self.supervisor_node, self.mb)

# Register render node
if self.render_node:
Expand All @@ -76,6 +81,9 @@ def __init__(self, name: str, rate: float, graph: Graph, bridge: BridgeSpec) ->
# Register objects
self.register_objects(objects)

# Implement clean up
atexit.register(self.shutdown)

def _init_supervisor(self, bridge: BridgeSpec, nodes: List[NodeSpec], objects: List[ObjectSpec]):
# Initialize supervisor
supervisor = self.create_supervisor()
Expand Down Expand Up @@ -250,6 +258,7 @@ def _init_environment(self, actions: NodeSpec, observations: NodeSpec, superviso

@property
def observation_space(self) -> gym.spaces.Dict:
assert not self.has_shutdown, "This environment has been shutdown."
observation_space = dict()
for name, buffer in self.env_node.observation_buffer.items():
space = buffer["converter"].get_space()
Expand All @@ -263,6 +272,7 @@ def observation_space(self) -> gym.spaces.Dict:

@property
def action_space(self) -> gym.spaces.Dict:
assert not self.has_shutdown, "This environment has been shutdown."
action_space = dict()
for name, buffer in self.env_node.action_buffer.items():
action_space[name] = buffer["converter"].get_space()
Expand Down Expand Up @@ -317,6 +327,7 @@ def _initialize(self) -> None:
rospy.loginfo("Pipelines initialized.")

def _reset(self, states: Dict) -> Dict:
assert not self.has_shutdown, "This environment has been shutdown."
# Initialize environment
if not self.initialized:
self._initialize()
Expand All @@ -332,6 +343,7 @@ def _reset(self, states: Dict) -> Dict:
def _step(self, action: Dict) -> Dict:
# Check that nodes were previously initialized.
assert self.initialized, "Not yet initialized. Call .reset() before calling .step()."
assert not self.has_shutdown, "This environment has been shutdown."

# Set actions in buffer
self._set_action(action)
Expand All @@ -341,16 +353,27 @@ def _step(self, action: Dict) -> Dict:
return self._get_observation()

def _shutdown(self):
    """Tear down everything this environment started, at most once.

    In order: terminates every node in ``supervisor_node.launch_nodes``, calls
    ``node_shutdown()`` on every node in ``supervisor_node.sp_nodes``, on the
    supervisor and on the environment node, shuts down the message broker and
    finally deletes the ROS parameters stored under ``/<name>``.

    Calls made after the first completed shutdown return immediately, because
    ``has_shutdown`` is set at the end of the first run.
    """
    # NOTE(review): shutdown() is registered with atexit in __init__ and presumably
    # delegates here, so this can be reached again after an explicit shutdown.
    if self.has_shutdown:
        return

    for address, node in self.supervisor_node.launch_nodes.items():
        rospy.loginfo(f"[{self.name}] Send termination signal to '{address}'.")
        node.terminate()
    for rxnode in self.supervisor_node.sp_nodes.values():
        rospy.loginfo(f"[{self.name}][{rxnode.name}] Shutting down.")
        rxnode.node_shutdown()
    self.supervisor.node_shutdown()
    self.env.node_shutdown()
    self.mb.shutdown()
    try:
        rosparam.delete_param(f"/{self.name}")
        rospy.loginfo(f'Parameters under namespace "/{self.name}" deleted.')
    except ROSMasterException as e:
        # Failing to delete the parameters is only reported; it must not abort the shutdown.
        rospy.logwarn(e)
    self.has_shutdown = True

def register_nodes(self, nodes: Union[List[NodeSpec], NodeSpec]) -> None:
assert not self.has_shutdown, "This environment has been shutdown."
# Look-up via <env_name>/<obj_name>/nodes/<component_type>/<component>: /rx/obj/nodes/sensors/pos_sensors
if not isinstance(nodes, list):
nodes = [nodes]
Expand All @@ -359,6 +382,7 @@ def register_nodes(self, nodes: Union[List[NodeSpec], NodeSpec]) -> None:
[self.supervisor_node.register_node(n) for n in nodes]

def register_objects(self, objects: Union[List[ObjectSpec], ObjectSpec]) -> None:
assert not self.has_shutdown, "This environment has been shutdown."
# Look-up via <env_name>/<obj_name>/nodes/<component_type>/<component>: /rx/obj/nodes/sensors/pos_sensors
if not isinstance(objects, list):
objects = [objects]
Expand All @@ -367,6 +391,7 @@ def register_objects(self, objects: Union[List[ObjectSpec], ObjectSpec]) -> None
[self.supervisor_node.register_object(o, self._bridge_name) for o in objects]

def render(self, mode="human"):
assert not self.has_shutdown, "This environment has been shutdown."
if self.render_node:
if mode == "human":
self.supervisor_node.start_render()
Expand Down Expand Up @@ -403,6 +428,7 @@ def step(self, action: Dict) -> Tuple[Dict, float, bool, Dict]:
pass

def close(self):
    """Stop rendering via the supervisor node.

    Raises:
        AssertionError: if the environment has already been shut down.
    """
    assert not self.has_shutdown, "This environment has been shutdown."
    supervisor = self.supervisor_node
    supervisor.stop_render()

def shutdown(self):
Expand Down
34 changes: 25 additions & 9 deletions eagerx/core/executable_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_param_with_blocking,
get_opposite_msg_cls,
)
from eagerx.core.executable_node import RxNode
from eagerx.utils.node_utils import wait_for_node_initialization
from eagerx.core.constants import log_levels_ROS

Expand All @@ -27,6 +28,7 @@ def __init__(self, name, message_broker):
self.ns = "/".join(name.split("/")[:2])
self.mb = message_broker
self.initialized = False
self.has_shutdown = False

# Prepare input & output topics
(
Expand Down Expand Up @@ -56,7 +58,7 @@ def __init__(self, name, message_broker):
self.cond_reg = Condition()

# Prepare closing routine
rospy.on_shutdown(self._close)
rospy.on_shutdown(self.node_shutdown)

def node_initialized(self):
with self.cond_reg:
Expand All @@ -65,8 +67,8 @@ def node_initialized(self):

# Notify env that node is initialized
if not self.initialized:
init_pub = rospy.Publisher(self.name + "/initialized", UInt64, queue_size=0, latch=True)
init_pub.publish(UInt64(data=1))
self.init_pub = rospy.Publisher(self.name + "/initialized", UInt64, queue_size=0, latch=True)
self.init_pub.publish(UInt64(data=1))
rospy.loginfo('Node "%s" initialized.' % self.name)
self.initialized = True

Expand Down Expand Up @@ -114,11 +116,24 @@ def _prepare_io_topics(self, name):
node,
)

def _shutdown(self):
    """Unregister the publisher used to announce that this bridge is initialized."""
    rospy.logdebug(f"[{self.name}] RxBridge._shutdown() called.")
    # init_pub is only created inside node_initialized(); a shutdown that happens
    # before initialization must not fail on the missing attribute.
    init_pub = getattr(self, "init_pub", None)
    if init_pub is not None:
        init_pub.unregister()

def node_shutdown(self):
    """Shut down the bridge and every node it manages, at most once.

    In order: terminates every node in ``bridge.launch_nodes``, calls
    ``node_shutdown()`` on every node in ``bridge.sp_nodes``, unregisters this
    wrapper's own publisher via ``_shutdown()``, shuts down the bridge itself
    and the message broker, and finally sets ``has_shutdown``.
    """
    rospy.logdebug(f"[{self.name}] RxBridge.node_shutdown() called.")
    # This method is registered as a rospy shutdown hook and is also a public
    # entry point, so it can be invoked more than once; only the first call cleans up.
    if self.has_shutdown:
        return

    for address, node in self.bridge.launch_nodes.items():
        rospy.loginfo(f"[{self.name}] Send termination signal to '{address}'.")
        node.terminate()
    for rxnode in self.bridge.sp_nodes.values():
        rospy.loginfo(f"[{self.name}] Shutting down '{rxnode.name}'.")
        rxnode.node_shutdown()
    rospy.loginfo(f"[{self.name}] Shutting down.")
    self._shutdown()
    self.bridge.shutdown()
    self.mb.shutdown()
    self.has_shutdown = True


if __name__ == "__main__":
Expand All @@ -143,5 +158,6 @@ def _close(self):

rospy.spin()
finally:
rospy.loginfo(f"[{ns}/{name}] Terminating.")
rospy.signal_shutdown(f"[{ns}/{name}] Terminating.")
if not pnode.has_shutdown:
rospy.loginfo(f"[{ns}/{name}] Send termination signal to '{ns}/{name}'.")
rospy.signal_shutdown(f"[{ns}/{name}] Terminating '{ns}/{name}'.")
22 changes: 17 additions & 5 deletions eagerx/core/executable_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self, name, message_broker, **kwargs):
self.ns = "/".join(name.split("/")[:2])
self.mb = message_broker
self.initialized = False
self.has_shutdown = False

# Prepare input & output topics
(
Expand All @@ -51,12 +52,12 @@ def __init__(self, name, message_broker, **kwargs):
self.mb.add_rx_objects(node_name=name, node=self, **rx_objects)

# Prepare closing routine
rospy.on_shutdown(self._close)
rospy.on_shutdown(self.node_shutdown)

def node_initialized(self):
# Notify env that node is initialized
init_pub = rospy.Publisher(self.name + "/initialized", UInt64, queue_size=0, latch=True)
init_pub.publish(UInt64())
self.init_pub = rospy.Publisher(self.name + "/initialized", UInt64, queue_size=0, latch=True)
self.init_pub.publish(UInt64())

if not self.initialized:
rospy.loginfo('Node "%s" initialized.' % self.name)
Expand Down Expand Up @@ -133,8 +134,17 @@ def _prepare_io_topics(self, name, **kwargs):
node,
)

def _shutdown(self):
    """Unregister the publisher used to announce that this node is initialized."""
    rospy.logdebug(f"[{self.name}] RxNode._shutdown() called.")
    # init_pub is only assigned inside node_initialized(); a shutdown that happens
    # before initialization must not fail on the missing attribute.
    init_pub = getattr(self, "init_pub", None)
    if init_pub is not None:
        init_pub.unregister()

def node_shutdown(self):
    """Shut down this node wrapper, at most once.

    Unregisters the wrapper's own publisher via ``_shutdown()``, calls
    ``shutdown()`` on the wrapped node and on the message broker, and finally
    sets ``has_shutdown``.
    """
    rospy.logdebug(f"[{self.name}] RxNode.node_shutdown() called.")
    # This method is registered as a rospy shutdown hook and is also called
    # explicitly by its owner, so it can run twice; only the first call cleans up.
    if self.has_shutdown:
        return

    rospy.loginfo(f"[{self.name}] Shutting down.")
    self._shutdown()
    self.node.shutdown()
    self.mb.shutdown()
    self.has_shutdown = True


if __name__ == "__main__":
Expand All @@ -159,4 +169,6 @@ def _close(self):

rospy.spin()
finally:
rospy.signal_shutdown(f"Terminating '{ns}/{name}'")
if not pnode.has_shutdown:
rospy.loginfo(f"[{ns}/{name}] Send termination signal to '{ns}/{name}'.")
rospy.signal_shutdown(f"Terminating '{ns}/{name}'")
4 changes: 4 additions & 0 deletions eagerx/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,7 @@ def callback(self, t_n: float, image: Optional[Msg] = None):
# Fill output_msg with 'done' output --> signals that we are done rendering
output_msgs = dict(done=UInt64())
return output_msgs

def shutdown(self):
    """Log the shutdown call, then close every OpenCV window."""
    rospy.logdebug(f"[{self.name}] {self.name}.shutdown() called.")
    cv2.destroyAllWindows()
35 changes: 31 additions & 4 deletions eagerx/core/rx_message_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def __init__(self, owner):
self.connected_ros = dict()
self.connected_rx = dict()

# All publishers and subscribers (grouped to unregister when shutting down)
self._publishers = []
self.subscribers = []

# Every method is wrapped in a 'with Condition' block in order to be threadsafe
def __getattribute__(self, name):
attr = super(RxMessageBroker, self).__getattribute__(name)
Expand Down Expand Up @@ -161,11 +165,13 @@ def add_rx_objects(
on_next=i["msg_pub"].publish,
on_error=lambda e: print("Error : {0}".format(e)),
)
self._publishers.append(i["msg_pub"])
i["reset_pub"] = rospy.Publisher(i["address"] + "/reset", UInt64, queue_size=0, latch=True)
i["reset"].subscribe(
on_next=i["reset_pub"].publish,
on_error=lambda e: print("Error : {0}".format(e)),
)
self._publishers.append(i["reset_pub"])
for i in feedthrough:
address = i["address"]
cname_address = f"{i['feedthrough_to']}:{address}"
Expand Down Expand Up @@ -208,6 +214,7 @@ def add_rx_objects(
on_next=i["msg_pub"].publish,
on_error=lambda e: print("Error : {0}".format(e)),
)
self._publishers.append(i["msg_pub"])
for i in state_inputs:
address = i["address"]
try:
Expand Down Expand Up @@ -280,6 +287,7 @@ def add_rx_objects(
on_next=i["msg_pub"].publish,
on_error=lambda e: print("Error : {0}".format(e)),
)
self._publishers.append(i["msg_pub"])
for i in reactive_proxy:
address = i["address"]
cname_address = f"{i['name']}:{address}"
Expand All @@ -300,6 +308,7 @@ def add_rx_objects(
on_next=i["reset_pub"].publish,
on_error=lambda e: print("Error : {0}".format(e)),
)
self._publishers.append(i["reset_pub"])

# Add new addresses to already registered I/Os
for key in n.keys():
Expand Down Expand Up @@ -430,7 +439,7 @@ def connect_io(self, print_status=True):
rate_str = "|" + "".center(3, " ")
msg_type = entry["msg_type"]
self.connected_ros[node_name][key][cname_address] = entry
T = from_topic(msg_type, address, node_name=node_name)
T = from_topic(msg_type, address, node_name, self.subscribers)

# Subscribe and change status
entry["disposable"] = T.subscribe(entry["rx"])
Expand Down Expand Up @@ -470,13 +479,31 @@ def _split_cname_address(self, cname_address):
def _assert_already_registered(self, name, d, component):
assert name not in d[component], f'Cannot re-register the same address ({name}) twice as "{component}".'

def shutdown(self):
    """Unregister every publisher and subscriber tracked by this broker."""
    rospy.logdebug(f"[{self.owner}] RxMessageBroker.shutdown() called.")
    # Plain loops: these calls are made only for their side effect.
    for pub in self._publishers:
        pub.unregister()
    for sub in self.subscribers:
        sub.unregister()


def from_topic(topic_type: Any, topic_name: str, node_name, subscribers: list) -> Observable:
    """Create an observable that forwards the messages received on a ROS topic.

    Args:
        topic_type: Message class of the topic.
        topic_name: Name of the topic to subscribe to.
        node_name: Name used as a prefix when logging a failed subscription.
        subscribers: List to which the created ``rospy.Subscriber`` is appended,
            so that the caller can unregister it later.

    Returns:
        An observable; each subscription to it creates one ``rospy.Subscriber``
        whose messages are passed to the observer's ``on_next``.

    Raises:
        Exception: whatever creating the subscriber raised, re-raised after a warning is logged.
    """

    def _subscribe(observer, scheduler=None) -> Disposable:
        try:
            # The callback needs the subscriber it belongs to, but that object only
            # exists after construction; it is handed over through this list.
            wrapped_sub = []

            def cb_from_topic(msg, wrapped_sub):
                try:
                    observer.on_next(msg)
                except rospy.exceptions.ROSException as e:
                    # Forwarding the message failed: stop listening to this topic.
                    sub = wrapped_sub[0]
                    sub.unregister()
                    rospy.logdebug(f"[{sub.name}]: Unregistered this subscription because of exception: {e}")

            sub = rospy.Subscriber(topic_name, topic_type, callback=cb_from_topic, callback_args=wrapped_sub)
            wrapped_sub.append(sub)
            subscribers.append(sub)
        except Exception as e:
            rospy.logwarn("[%s]: %s" % (node_name, e))
            raise
        return observer

    return create(_subscribe)
Loading

0 comments on commit 0da70fc

Please sign in to comment.