Skip to content

Commit

Permalink
env returns camera obs as dict
Browse files Browse the repository at this point in the history
  • Loading branch information
lukashermann committed Nov 22, 2021
1 parent accfb2e commit 1883198
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 25 deletions.
20 changes: 4 additions & 16 deletions calvin_env/envs/play_lmp_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,8 @@

import gym

try:
from lfp.datasets.utils.episode_utils import process_actions, process_depth, process_rgb, process_state
except ImportError:
from calvin_agent.datasets.utils.episode_utils import (
process_actions,
process_depth,
process_rgb,
process_state,
)

from calvin_agent.datasets.utils.episode_utils import process_actions, process_depth, process_rgb, process_state
import numpy as np
import torch

Expand Down Expand Up @@ -54,14 +47,9 @@ def set_egl_device(device):
logger.info(f"EGL_DEVICE_ID {egl_id} <==> CUDA_DEVICE_ID {cuda_id}")

def transform_observation(self, obs: Dict[str, Any]) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
rgb_obs = {f"rgb_{name}": img for name, img in zip([cam.name for cam in self.env.cameras], obs["rgb_obs"])}
depth_obs = {
f"depth_{name}": img for name, img in zip([cam.name for cam in self.env.cameras], obs["depth_obs"])
}

state_obs = process_state(obs, self.observation_space_keys, self.transforms, self.proprio_state)
rgb_obs = process_rgb(rgb_obs, self.observation_space_keys, self.transforms)
depth_obs = process_depth(depth_obs, self.observation_space_keys, self.transforms)
rgb_obs = process_rgb(obs["rgb_obs"], self.observation_space_keys, self.transforms)
depth_obs = process_depth(obs["depth_obs"], self.observation_space_keys, self.transforms)

state_obs["robot_obs"] = state_obs["robot_obs"].to(self.device).unsqueeze(0)
rgb_obs.update({"rgb_obs": {k: v.to(self.device).unsqueeze(0) for k, v in rgb_obs["rgb_obs"].items()}})
Expand Down
18 changes: 9 additions & 9 deletions calvin_env/envs/play_table_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,15 @@ def render(self, mode="human"):
"""render is gym compatibility function"""
rgb_obs, depth_obs = self.get_camera_obs()
if mode == "human":
if len(rgb_obs) == 0:
log.warning("Environment does not have camera")
if "rgb_static" not in rgb_obs:
log.warning("Environment does not have static camera")
return
img = rgb_obs[0][:, :, ::-1]
img = rgb_obs["rgb_static"][:, :, ::-1]
cv2.imshow("simulation cam", cv2.resize(img, (500, 500)))
cv2.waitKey(1)
elif mode == "rgb_array":
assert len(rgb_obs) > 0, "Environment does not have camera"
return rgb_obs[0]
assert "rgb_static" in rgb_obs, "Environment does not have static camera"
return rgb_obs["rgb_static"]
else:
raise NotImplementedError

Expand All @@ -177,12 +177,12 @@ def seed(self, seed=None):

def get_camera_obs(self):
assert self.cameras is not None
rgb_obs = []
depth_obs = []
rgb_obs = {}
depth_obs = {}
for cam in self.cameras:
rgb, depth = cam.render()
rgb_obs.append(rgb)
depth_obs.append(depth)
rgb_obs[f"rgb_{cam.name}"] = rgb
depth_obs[f"depth_{cam.name}"] = depth
return rgb_obs, depth_obs

def get_obs(self):
Expand Down

0 comments on commit 1883198

Please sign in to comment.