Skip to content

Commit

Permalink
speed up dataset loading
Browse files Browse the repository at this point in the history
  • Loading branch information
lukashermann committed Oct 27, 2021
1 parent 5e92d61 commit 937bd46
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions calvin/datasets/npz_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from pathlib import Path
import re
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -44,13 +45,21 @@ def __init__(self, *args, skip_frames: int = 0, n_digits: Optional[int] = None,
else:
self.episode_lookup, self.max_batched_length_per_demo = self.load_file_indices(self.abs_datasets_dir)

glob_generator = self.abs_datasets_dir.glob(f"*.{self.save_format}")
file_names = [x for x in glob_generator if x.is_file()]
aux_naming_pattern = re.split(r"\d+", file_names[0].stem)
self.naming_pattern = [file_names[0].parent / aux_naming_pattern[0], file_names[0].suffix]
self.n_digits = n_digits if n_digits is not None else len(re.findall(r"\d+", file_names[0].stem)[0])
assert len(self.naming_pattern) == 2
assert self.n_digits > 0
self.naming_pattern, self.n_digits = self.lookup_naming_pattern(n_digits)

def lookup_naming_pattern(self, n_digits: Optional[int]) -> Tuple[List, int]:
    """
    Derive the episode file naming scheme from one file in the dataset directory.

    Only the first matching directory entry is inspected, so the whole
    directory never has to be listed.

    Args:
        n_digits: Width of the numeric part of the file names. If None, it is
            taken from the first run of digits in the inspected file's stem.

    Returns:
        naming_pattern: Two-element list ``[prefix_path, suffix]``, where
            ``prefix_path`` is the file's parent directory joined with the part
            of the stem before the first digit run, and ``suffix`` is the file
            extension including the leading dot.
        n_digits: Number of digits in the numeric part of the file names.

    Raises:
        FileNotFoundError: If no regular file whose suffix contains
            ``self.save_format`` exists in ``self.abs_datasets_dir``.
        ValueError: If ``n_digits`` is None and the inspected file name has no
            digits, or if the resulting ``n_digits`` is not positive.
    """
    filename = None
    # scandir holds an open directory handle; the with-block releases it as
    # soon as a match is found instead of waiting for garbage collection.
    with os.scandir(self.abs_datasets_dir) as entries:
        for entry in entries:
            candidate = Path(entry)
            # Substring test on the suffix is kept from the previous
            # implementation; is_file() skips directories with a matching name.
            if self.save_format in candidate.suffix and entry.is_file():
                filename = candidate
                break
    if filename is None:
        raise FileNotFoundError(
            f"No '{self.save_format}' file found in {self.abs_datasets_dir}"
        )

    if n_digits is None:
        digit_runs = re.findall(r"\d+", filename.stem)
        if not digit_runs:
            raise ValueError(
                f"Cannot infer number of digits: no digits in file name '{filename.name}'"
            )
        n_digits = len(digit_runs[0])
    if n_digits <= 0:
        raise ValueError(f"n_digits must be positive, got {n_digits}")

    # Everything before the first digit run is the common file name prefix.
    prefix = re.split(r"\d+", filename.stem)[0]
    naming_pattern = [filename.parent / prefix, filename.suffix]
    return naming_pattern, n_digits


def get_episode_name(self, idx: int) -> Path:
"""
Expand Down

0 comments on commit 937bd46

Please sign in to comment.