Skip to content

Commit

Permalink
Save edits from live
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Apr 5, 2022
1 parent fd61d84 commit 2407f80
Showing 1 changed file with 30 additions and 13 deletions.
43 changes: 30 additions & 13 deletions ae/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Tuple

import cv2
import gym
import numpy as np

Expand All @@ -9,25 +10,41 @@

class AutoencoderWrapper(gym.Wrapper):
"""
Gym wrapper to encode image and reduce input dimension
using pre-trained auto-encoder
(only the encoder part is used here, decoder part can be used for debug)
Wrapper to encode input image using pre-trained AutoEncoder
:param env: Gym environment
:param ae_path: Path to the autoencoder
:param ae_path: absolute path to the pretrained AutoEncoder
"""

def __init__(self, env: gym.Env, ae_path: Optional[str] = os.environ.get("AAE_PATH")): # noqa: B008
def __init__(self, env: gym.Env, ae_path: str = os.environ["AE_PATH"]):
super().__init__(env)
assert ae_path is not None, "No path to autoencoder was provided"
self.ae = load_ae(ae_path)
# Update observation space
self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.ae.z_size,), dtype=np.float32)
self.autoencoder = load_ae(ae_path)
self.observation_space = gym.spaces.Box(
low=-np.inf,
high=np.inf,
shape=(self.autoencoder.z_size + 1,),
dtype=np.float32,
)

def reset(self) -> np.ndarray:
# Important: Convert to BGR to match OpenCV convention
return self.ae.encode_from_raw_image(self.env.reset()[:, :, ::-1]).flatten()
obs = self.env.reset()
# Convert to BGR
encoded_image = self.autoencoder.encode_from_raw_image(obs[:, :, ::-1])
new_obs = np.concatenate([encoded_image.flatten(), [0.0]])
return new_obs.flatten()

def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
obs, reward, done, infos = self.env.step(action)
return self.ae.encode_from_raw_image(obs[:, :, ::-1]).flatten(), reward, done, infos
# Encode with the pre-trained AutoEncoder
encoded_image = self.autoencoder.encode_from_raw_image(obs[:, :, ::-1])
# reconstructed_image = self.autoencoder.decode(encoded_image)[0]
# cv2.imshow("Original", obs[:, :, ::-1])
# cv2.imshow("Reconstruction", reconstructed_image)
# # stop if escape is pressed
# k = cv2.waitKey(0) & 0xFF
# if k == 27:
# pass
speed = infos["speed"]
new_obs = np.concatenate([encoded_image.flatten(), [speed]])

return new_obs.flatten(), reward, done, infos

0 comments on commit 2407f80

Please sign in to comment.