A minimal Decision Transformer implementation written in JAX (Flax). [Reference (minimal torch implementation)]
Set up the environment:
pip install -r requirements.txt
pip install -e .
Training examples:
# 20k training
CUDA_VISIBLE_DEVICES=0 python scripts/train_dt.py --env halfcheetah --dataset medium
CUDA_VISIBLE_DEVICES=0 python scripts/train_dt.py --env walker2d --dataset medium
CUDA_VISIBLE_DEVICES=0 python scripts/train_dt.py --env hopper --dataset medium
# 100k training (20 iterations x 5000 updates per iteration)
CUDA_VISIBLE_DEVICES=0 python scripts/train_dt.py --env halfcheetah --dataset medium --max_train_iters 20 --policy_save_iters 2 --num_updates_per_iter 5000
CUDA_VISIBLE_DEVICES=0 python scripts/train_dt.py --env walker2d --dataset medium --max_train_iters 20 --policy_save_iters 2 --num_updates_per_iter 5000
CUDA_VISIBLE_DEVICES=0 python scripts/train_dt.py --env hopper --dataset medium --max_train_iters 20 --policy_save_iters 2 --num_updates_per_iter 5000
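The "100k" label follows from the two flags if the total number of gradient updates is the product of `--max_train_iters` and `--num_updates_per_iter`. That reading is an assumption based on the comment, not checked against the training script. Below is a minimal Python sketch of the arithmetic that also assembles the three command lines. The helper names `total_updates` and `build_train_command` are hypothetical and are not part of this repository. Only flags shown in this README are used.

```python
import shlex

ENVS = ["halfcheetah", "walker2d", "hopper"]


def total_updates(max_train_iters: int, num_updates_per_iter: int) -> int:
    """Total gradient updates, ASSUMING each outer iteration performs
    num_updates_per_iter updates (not verified against train_dt.py)."""
    return max_train_iters * num_updates_per_iter


def build_train_command(env, dataset="medium", max_train_iters=20,
                        policy_save_iters=2, num_updates_per_iter=5000):
    """Assemble the training command from flags that appear in this README."""
    args = [
        "python", "scripts/train_dt.py",
        "--env", env,
        "--dataset", dataset,
        "--max_train_iters", str(max_train_iters),
        "--policy_save_iters", str(policy_save_iters),
        "--num_updates_per_iter", str(num_updates_per_iter),
    ]
    # shlex.join quotes arguments only when the shell would need it.
    return shlex.join(args)


if __name__ == "__main__":
    print(total_updates(20, 5000))  # 100000
    for env in ENVS:
        print("CUDA_VISIBLE_DEVICES=0", build_train_command(env))
```

Printing the commands rather than running them keeps the sketch free of side effects. The output can be pasted into a shell or passed to a job scheduler.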
Test example:
CUDA_VISIBLE_DEVICES=0 python scripts/test_dt.py --env halfcheetah --dataset medium --chk_pt_name dt_halfcheetah-medium-v2/seed_0/22-06-22-06-31-49/model_best.pt
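The value passed to `--chk_pt_name` above has four path components. The sketch below splits it so the pieces can be checked before launching a test run. Reading the components as run directory, seed directory, and timestamp directory is an assumption drawn from this one example path. `split_checkpoint_name` is a hypothetical helper, not a function from this repository.

```python
from pathlib import PurePosixPath


def split_checkpoint_name(chk_pt_name: str) -> dict:
    """Split a checkpoint name of the form run/seed_N/timestamp/file.

    The four-part layout is ASSUMED from the README example; other
    layouts raise ValueError from the tuple unpacking below.
    """
    run_dir, seed_dir, timestamp_dir, file_name = PurePosixPath(chk_pt_name).parts
    return {
        "run": run_dir,
        "seed": int(seed_dir.split("_", 1)[1]),  # "seed_0" -> 0
        "timestamp": timestamp_dir,
        "file": file_name,
    }


if __name__ == "__main__":
    name = "dt_halfcheetah-medium-v2/seed_0/22-06-22-06-31-49/model_best.pt"
    print(split_checkpoint_name(name))
```

The directory that these relative names are resolved against is not stated here, so the sketch deliberately does not touch the filesystem.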
Citation:
@inproceedings{furuta2021generalized,
  title={Generalized Decision Transformer for Offline Hindsight Information Matching},
  author={Hiroki Furuta and Yutaka Matsuo and Shixiang Shane Gu},
  booktitle={International Conference on Learning Representations},
  year={2022}
}