Skip to content

Commit

Permalink
add data set flag for vinn cache feature
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyzhaozh committed Dec 13, 2023
1 parent 7d3f51c commit d065a7b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions vinn_cache_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,12 @@ def main(args):
#################################################

ckpt_path = args.ckpt_path
dataset_dir = args.dataset_dir
ckpt_name = pathlib.PurePath(ckpt_path).name
dataset_name = ckpt_name.split('-')[1]
repr_type = ckpt_name.split('-')[0]
seed = int(ckpt_name.split('-')[-1][:-3])

dataset_dir = f'/home/mobile-aloha/data/{dataset_name}'

episode_idxs = [int(name.split('_')[1].split('.')[0]) for name in os.listdir(dataset_dir) if ('.hdf5' in name) and ('features' not in name)]
episode_idxs.sort()
assert len(episode_idxs) == episode_idxs[-1] + 1 # no holes
Expand Down Expand Up @@ -138,6 +137,7 @@ def main(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='cache features')
parser.add_argument('--ckpt_path', type=str, required=True, help='ckpt_path')
parser.add_argument('--dataset_dir', type=str, required=True, help='dataset_dir')
args = parser.parse_args()

main(args)

0 comments on commit d065a7b

Please sign in to comment.