add co-training options: sample_weights, list of dataset_dir, stats_dir, train_ratio
MarkFzp committed Nov 25, 2023
1 parent 30d1722 commit 4228c61
Showing 2 changed files with 68 additions and 38 deletions.
9 changes: 7 additions & 2 deletions imitate_episodes.py
@@ -51,9 +51,12 @@ def main(args):
 from aloha_scripts.constants import TASK_CONFIGS
 task_config = TASK_CONFIGS[task_name]
 dataset_dir = task_config['dataset_dir']
-num_episodes = task_config['num_episodes']
+# num_episodes = task_config['num_episodes']
 episode_len = task_config['episode_len']
 camera_names = task_config['camera_names']
+stats_dir = task_config.get('stats_dir', None)
+sample_weights = task_config.get('sample_weights', None)
+train_ratio = task_config.get('train_ratio', 0.99)
 name_filter = task_config.get('name_filter', lambda n: True)

 # fixed parameters
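For orientation (not part of the diff): with these keys, a co-training entry in TASK_CONFIGS (aloha_scripts/constants.py) can point at several dataset directories and control how they are mixed. The task name, paths, and numbers in the sketch below are placeholders, not values from this commit:

# Hypothetical TASK_CONFIGS entry illustrating the new co-training keys.
# All names, paths, and numbers here are placeholders.
TASK_CONFIGS = {
    'example_cotrain_task': {
        'dataset_dir': ['/data/task_primary', '/data/task_auxiliary'],  # list of dirs; the first one is also used for validation
        'episode_len': 500,
        'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
        'stats_dir': '/data/task_primary',  # optional; normalization stats default to dataset_dir
        'sample_weights': [0.8, 0.2],       # optional; one weight per dataset dir, normalized internally
        'train_ratio': 0.95,                # optional; defaults to 0.99
        # 'num_episodes' is no longer required: episodes are discovered from the hdf5 files
    },
}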
@@ -151,7 +154,7 @@ def main(args):
 print()
 exit()

-train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, name_filter, camera_names, batch_size_train, batch_size_val, args['chunk_size'], args['skip_mirrored_data'], config['load_pretrain'], policy_class)
+train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, name_filter, camera_names, batch_size_train, batch_size_val, args['chunk_size'], args['skip_mirrored_data'], config['load_pretrain'], policy_class, stats_dir_l=stats_dir, sample_weights=sample_weights, train_ratio=train_ratio)

 # save dataset stats
 stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
@@ -492,6 +495,8 @@ def train_bc(train_dataloader, val_dataloader, config):
 for batch_idx, data in enumerate(val_dataloader):
     forward_dict = forward_pass(data, policy)
     validation_dicts.append(forward_dict)
+    if batch_idx > 20:
+        break

 validation_summary = compute_dict_mean(validation_dicts)

97 changes: 61 additions & 36 deletions utils.py
@@ -10,34 +10,42 @@
 import IPython
 e = IPython.embed

+def flatten_list(l):
+    return [item for sublist in l for item in sublist]
+
 class EpisodicDataset(torch.utils.data.Dataset):
-    def __init__(self, dataset_path_list, camera_names, norm_stats, episode_ids, episode_len, chunk_size, policy_class):
+    def __init__(self, dataset_path_l, camera_names, norm_stats, episode_id_l, episode_len_l, chunk_size, policy_class, sample_weights=None):
         super(EpisodicDataset).__init__()
-        self.episode_ids = episode_ids
-        self.dataset_path_list = dataset_path_list
+        self.episode_id_l = episode_id_l
+        self.dataset_path_l = dataset_path_l
         self.camera_names = camera_names
         self.norm_stats = norm_stats
-        self.episode_len = episode_len
+        self.episode_len_l = episode_len_l
+        self.sum_episode_len_l = [sum(episode_len) for episode_len in episode_len_l]
         self.chunk_size = chunk_size
-        self.cumulative_len = np.cumsum(self.episode_len)
-        self.max_episode_len = max(episode_len)
+        self.cumulative_len_l = [np.cumsum(episode_len) for episode_len in episode_len_l]
+        self.max_episode_len = max([max(episode_len) for episode_len in episode_len_l])
         self.policy_class = policy_class
+        self.sample_weights = np.array(sample_weights) / np.sum(sample_weights) if sample_weights is not None else None # None is uniform sampling
         self.__getitem__(0) # initialize self.is_sim
         self.is_sim = False

+    # dummy
     def __len__(self):
-        return sum(self.episode_len)
-
-    def _locate_transition(self, index):
-        assert index < self.cumulative_len[-1]
-        episode_index = np.argmax(self.cumulative_len > index) # argmax returns first True index
-        start_ts = index - (self.cumulative_len[episode_index] - self.episode_len[episode_index])
-        episode_id = self.episode_ids[episode_index]
-        return episode_id, start_ts
-
-    def __getitem__(self, index):
-        episode_id, start_ts = self._locate_transition(index)
-        dataset_path = self.dataset_path_list[episode_id]
+        return sum(self.episode_len_l[0])
+
+    def _locate_transition(self):
+        dataset_idx = np.random.choice(len(self.dataset_path_l), p=self.sample_weights)
+        step_idx = np.random.randint(self.sum_episode_len_l[dataset_idx])
+
+        episode_idx = np.argmax(self.cumulative_len_l[dataset_idx] > step_idx) # argmax returns first True index
+        start_ts = step_idx - (self.cumulative_len_l[dataset_idx][episode_idx] - self.episode_len_l[dataset_idx][episode_idx])
+        episode_id = self.episode_id_l[dataset_idx][episode_idx]
+        return dataset_idx, episode_id, start_ts
+
+    def __getitem__(self, _):
+        dataset_idx, episode_id, start_ts = self._locate_transition()
+        dataset_path = self.dataset_path_l[dataset_idx][episode_id]
         try:
             # print(dataset_path)
             with h5py.File(dataset_path, 'r') as root:
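For orientation (not part of the diff): the new _locate_transition samples in two stages, first choosing a dataset according to the normalized sample_weights (uniform when sample_weights is None), then drawing a timestep uniformly within that dataset and mapping it back to an (episode, start_ts) pair via the cumulative episode lengths. A minimal standalone sketch with made-up episode lengths:

import numpy as np

# Two datasets with placeholder episode lengths; weights favor the first dataset.
episode_len_l = [[400, 500], [300]]
sample_weights = np.array([0.8, 0.2])
sample_weights = sample_weights / sample_weights.sum()

cumulative_len_l = [np.cumsum(episode_len) for episode_len in episode_len_l]
sum_episode_len_l = [sum(episode_len) for episode_len in episode_len_l]

dataset_idx = np.random.choice(len(episode_len_l), p=sample_weights)  # stage 1: pick a dataset
step_idx = np.random.randint(sum_episode_len_l[dataset_idx])          # stage 2: pick a global step within it
episode_idx = np.argmax(cumulative_len_l[dataset_idx] > step_idx)     # first episode whose cumulative length exceeds the step
start_ts = step_idx - (cumulative_len_l[dataset_idx][episode_idx] - episode_len_l[dataset_idx][episode_idx])
print(dataset_idx, episode_idx, start_ts)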
@@ -82,7 +90,6 @@ def __getitem__(self, index):
 is_pad = np.zeros(self.max_episode_len)
 is_pad[action_len:] = 1

-padded_action = padded_action[:self.chunk_size]
 padded_action = padded_action[:self.chunk_size]

 # new axis for different cameras
@@ -180,32 +187,50 @@ def find_all_hdf5(dataset_dir, skip_mirrored_data):
     return hdf5_files


-def load_data(dataset_dir, name_filter, camera_names, batch_size_train, batch_size_val, chunk_size, skip_mirrored_data=False, load_pretrain=False, policy_class=None):
-    dataset_path_list = find_all_hdf5(dataset_dir, skip_mirrored_data)
-    dataset_path_list = [n for n in dataset_path_list if name_filter(n)]
-    num_episodes = len(dataset_path_list)
+def load_data(dataset_dir_l, name_filter, camera_names, batch_size_train, batch_size_val, chunk_size, skip_mirrored_data=False, load_pretrain=False, policy_class=None, stats_dir_l=None, sample_weights=None, train_ratio=0.99):
+    if type(dataset_dir_l) == str:
+        dataset_dir_l = [dataset_dir_l]
+    train_dataset_path_l = [find_all_hdf5(dataset_dir, skip_mirrored_data) for dataset_dir in dataset_dir_l]
+    # val only on dataset_dir_l[0]
+    val_dataset_path_l = [train_dataset_path_l[0]]
+    # train_dataset_path_l = [n for n in train_dataset_path_l if name_filter(n)]
+
+    num_episodes = [len(dataset_path_list) for dataset_path_list in train_dataset_path_l]

-    # obtain train test split
-    train_ratio = 0.995
-    shuffled_episode_ids = np.random.permutation(num_episodes)
-    train_episode_ids = shuffled_episode_ids[:int(train_ratio * num_episodes)]
-    val_episode_ids = shuffled_episode_ids[int(train_ratio * num_episodes):]
-    print(f'\n\nData from: {dataset_dir}\n- Train on {len(train_episode_ids)} episodes\n- Test on {len(val_episode_ids)} episodes\n\n')
+    # obtain train val split for dataset_dir_l[0]
+    num_episodes_0 = num_episodes[0]
+    shuffled_episode_id_0 = np.random.permutation(num_episodes_0)
+    train_episode_id_0 = shuffled_episode_id_0[:int(train_ratio * num_episodes_0)]
+    val_episode_id_0 = shuffled_episode_id_0[int(train_ratio * num_episodes_0):]
+    train_episode_id_l = [train_episode_id_0]
+    val_episode_id_l = [val_episode_id_0]
+    for num_episode in num_episodes[1:]:
+        train_episode_id_l.append(np.arange(num_episode, dtype=np.int))
+    print(f'\n\nData from: {dataset_dir_l}\n- Train on {[len(train_episode_id) for train_episode_id in train_episode_id_l]} episodes\n- Test on {[len(val_episode_id) for val_episode_id in val_episode_id_l]} episodes\n\n')

     # obtain normalization stats for qpos and action
     # if load_pretrain:
     # with open(os.path.join('/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all', 'dataset_stats.pkl'), 'rb') as f:
     # norm_stats = pickle.load(f)
     # print('Loaded pretrain dataset stats')
-    norm_stats, all_episode_len = get_norm_stats(dataset_path_list)
-    train_episode_len = [all_episode_len[i] for i in train_episode_ids]
-    val_episode_len = [all_episode_len[i] for i in val_episode_ids]
+    train_episode_len_l = []
+    val_episode_len_l = []
+    for idx, (train_dataset_path, train_episode_id) in enumerate(zip(train_dataset_path_l, train_episode_id_l)):
+        _, all_episode_len = get_norm_stats(train_dataset_path)
+        train_episode_len_l.append([all_episode_len[i] for i in train_episode_id])
+        if idx == 0:
+            val_episode_len_l.append([all_episode_len[i] for i in val_episode_id_0])
+    if stats_dir_l is None:
+        stats_dir_l = dataset_dir_l
+    elif type(stats_dir_l) == str:
+        stats_dir_l = [stats_dir_l]
+    norm_stats, _ = get_norm_stats(flatten_list([find_all_hdf5(stats_dir, skip_mirrored_data) for stats_dir in stats_dir_l]))

     # construct dataset and dataloader
-    train_dataset = EpisodicDataset(dataset_path_list, camera_names, norm_stats, train_episode_ids, train_episode_len, chunk_size, policy_class)
-    val_dataset = EpisodicDataset(dataset_path_list, camera_names, norm_stats, val_episode_ids, val_episode_len, chunk_size, policy_class)
-    train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
-    val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
+    train_dataset = EpisodicDataset(train_dataset_path_l, camera_names, norm_stats, train_episode_id_l, train_episode_len_l, chunk_size, policy_class, sample_weights=sample_weights)
+    val_dataset = EpisodicDataset(val_dataset_path_l, camera_names, norm_stats, val_episode_id_l, val_episode_len_l, chunk_size, policy_class, sample_weights=None)
+    train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=2, prefetch_factor=2)
+    val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=2, prefetch_factor=2)

     return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim

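For orientation (not part of the diff): a call to the updated load_data with the co-training options could look like the sketch below; the directories, batch sizes, and weights are placeholders, and the keyword names match the new signature above:

# Placeholder invocation of the extended load_data (values are illustrative only).
train_dataloader, val_dataloader, norm_stats, is_sim = load_data(
    ['/data/task_primary', '/data/task_auxiliary'],  # list of dataset dirs; validation episodes come from the first
    name_filter=lambda n: True,
    camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
    batch_size_train=8,
    batch_size_val=8,
    chunk_size=100,
    policy_class='ACT',
    stats_dir_l='/data/task_primary',   # where normalization stats are computed (defaults to the dataset dirs)
    sample_weights=[0.8, 0.2],          # per-dataset sampling probabilities for the training set
    train_ratio=0.95,                   # train/val split ratio applied to the first dataset
)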
