Skip to content

Commit

Permalink
reduce train_num_workers if os.getlogin() == 'zfu'
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkFzp committed Dec 15, 2023
1 parent 3794f96 commit ad44c2a
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,11 @@ def load_data(dataset_dir_l, name_filter, camera_names, batch_size_train, batch_
# construct dataset and dataloader
train_dataset = EpisodicDataset(dataset_path_list, camera_names, norm_stats, train_episode_ids, train_episode_len, chunk_size, policy_class)
val_dataset = EpisodicDataset(dataset_path_list, camera_names, norm_stats, val_episode_ids, val_episode_len, chunk_size, policy_class)
train_dataloader = DataLoader(train_dataset, batch_sampler=batch_sampler_train, pin_memory=True, num_workers=16, prefetch_factor=2)
val_dataloader = DataLoader(val_dataset, batch_sampler=batch_sampler_val, pin_memory=True, num_workers=8, prefetch_factor=2)
train_num_workers = (8 if os.getlogin() == 'zfu' else 16) if train_dataset.augment_images else 2
val_num_workers = 8 if train_dataset.augment_images else 2
print(f'Augment images: {train_dataset.augment_images}, train_num_workers: {train_num_workers}, val_num_workers: {val_num_workers}')
train_dataloader = DataLoader(train_dataset, batch_sampler=batch_sampler_train, pin_memory=True, num_workers=train_num_workers, prefetch_factor=2)
val_dataloader = DataLoader(val_dataset, batch_sampler=batch_sampler_val, pin_memory=True, num_workers=val_num_workers, prefetch_factor=2)

return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim

Expand Down

0 comments on commit ad44c2a

Please sign in to comment.