Skip to content

Commit

Permalink
suport omni-directional-camera
Browse files Browse the repository at this point in the history
  • Loading branch information
masayoshi-nakamura committed Apr 19, 2016
1 parent a2b16d5 commit 0558b30
Show file tree
Hide file tree
Showing 7 changed files with 1,030 additions and 59 deletions.
16 changes: 13 additions & 3 deletions python-agent/cnn_dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,17 @@ class CnnDqnAgent(object):
cnn_feature_extractor = 'alexnet_feature_extractor.pickle'
model = 'bvlc_alexnet.caffemodel'
model_type = 'alexnet'
image_feature_dim = 256 * 6 * 6
image_feature_dim = 256 * 6 * 6 * 4

def _osb_to_vec(self, observation):
return np.r_[self.feature_extractor.feature(observation["image"][0]),
self.feature_extractor.feature(observation["image"][1]),
self.feature_extractor.feature(observation["image"][2]),
self.feature_extractor.feature(observation["image"][3]),
observation["depth"][0],
observation["depth"][1],
observation["depth"][2],
observation["depth"][3]]

def agent_init(self, **options):
self.use_gpu = options['use_gpu']
Expand All @@ -41,7 +51,7 @@ def agent_init(self, **options):
self.q_net = QNet(self.use_gpu, self.actions, self.q_net_input_dim)

def agent_start(self, observation):
obs_array = np.r_[self.feature_extractor.feature(observation["image"]), observation["depth"]]
obs_array = self._osb_to_vec(observation)

# Initialize State
self.state = np.zeros((self.q_net.hist_size, self.q_net_input_dim), dtype=np.uint8)
Expand All @@ -62,7 +72,7 @@ def agent_start(self, observation):
return return_action

def agent_step(self, reward, observation):
obs_array = np.r_[self.feature_extractor.feature(observation["image"]), observation["depth"]]
obs_array = self._osb_to_vec(observation)

#obs_processed = np.maximum(obs_array, self.last_observation) # Take maximum from two frames

Expand Down
18 changes: 11 additions & 7 deletions python-agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,20 @@ class AgentServer(WebSocket):
log_file = 'reward.log'
reward_sum = 0
depth_image_dim = 32 * 32
depth_image_count = 4

def received_message(self, m):
payload = m.data

dat = msgpack.unpackb(payload)
image = Image.open(io.BytesIO(bytearray(dat['image'])))
depth = Image.open(io.BytesIO(bytearray(dat['depth'])))
# depth.save("depth_" + str(self.cycle_counter) + ".png")
# image.save("image_" + str(self.cycle_counter) + ".png")

image = []
for i in xrange(4):
image.append(Image.open(io.BytesIO(bytearray(dat['image'][i]))))
depth = []
for i in xrange(4):
d = (Image.open(io.BytesIO(bytearray(dat['depth'][i]))))
depth.append(np.array(ImageOps.grayscale(d)).reshape(self.depth_image_dim))

depth = np.array(ImageOps.grayscale(depth)).reshape(self.depth_image_dim)
observation = {"image": image, "depth": depth}
reward = dat['reward']
end_episode = dat['endEpisode']
Expand All @@ -61,7 +64,7 @@ def received_message(self, m):
print ("initializing agent...")
self.agent.agent_init(
use_gpu=args.gpu,
depth_image_dim=self.depth_image_dim)
depth_image_dim=self.depth_image_dim * self.depth_image_count)

action = self.agent.agent_start(observation)
self.send(str(action))
Expand All @@ -87,6 +90,7 @@ def received_message(self, m):

self.thread_event.set()


cherrypy.config.update({'server.socket_port': args.port})
WebSocketPlugin(cherrypy.engine).subscribe()
cherrypy.tools.websocket = WebSocketTool()
Expand Down
Loading

0 comments on commit 0558b30

Please sign in to comment.