Skip to content

Commit

Permalink
remove double obs norm and fix assert
Browse files Browse the repository at this point in the history
  • Loading branch information
amandlek committed Aug 21, 2023
1 parent 3dacf4d commit 488516a
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 7 deletions.
6 changes: 0 additions & 6 deletions robomimic/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,6 @@ def get_item(self, index):
seq_length=self.seq_length,
prefix="obs"
)
if self.hdf5_normalize_obs:
meta["obs"] = ObsUtils.normalize_obs(meta["obs"], obs_normalization_stats=self.obs_normalization_stats)

if self.load_next_obs:
meta["next_obs"] = self.get_obs_sequence_from_demo(
Expand All @@ -454,8 +452,6 @@ def get_item(self, index):
seq_length=self.seq_length,
prefix="next_obs"
)
if self.hdf5_normalize_obs:
meta["next_obs"] = ObsUtils.normalize_obs(meta["next_obs"], obs_normalization_stats=self.obs_normalization_stats)

if goal_index is not None:
goal = self.get_obs_sequence_from_demo(
Expand All @@ -466,8 +462,6 @@ def get_item(self, index):
seq_length=1,
prefix="next_obs",
)
if self.hdf5_normalize_obs:
goal = ObsUtils.normalize_obs(goal, obs_normalization_stats=self.obs_normalization_stats)
meta["goal_obs"] = {k: goal[k][0] for k in goal} # remove sequence dimension for goal

return meta
Expand Down
3 changes: 2 additions & 1 deletion robomimic/utils/obs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ def normalize_obs(obs_dict, obs_normalization_stats):
# check shape consistency
shape_len_diff = len(mean.shape) - len(obs_dict[m].shape)
assert shape_len_diff in [0, 1], "shape length mismatch in @normalize_obs"
assert mean.shape[shape_len_diff:] == obs_dict[m].shape, "shape mismatch in @normalize obs"
# if dict has no leading batch dim, check shapes match exactly, else allow first dim to broadcast
assert mean.shape[1:] == obs_dict[m].shape[(1 - shape_len_diff):], "shape mismatch in @normalize_obs"

# handle case where obs dict is not batched by removing stats batch dimension
if shape_len_diff == 1:
Expand Down

0 comments on commit 488516a

Please sign in to comment.