jil8885/rl4co (forked from ai4co/rl4co)

A PyTorch library for all things Reinforcement Learning (RL) for Combinatorial Optimization (CO)

Description

Code repository for RL4CO, built on TorchRL and following the Lightning-Hydra-Template best practices.

How to run

Clone the project and install dependencies:

git clone https://github.com/kaist-silab/rl4co && cd rl4co
pip install light-the-torch && python3 -m light_the_torch install --upgrade -r requirements.txt

The script above uses light-the-torch to install the PyTorch build that matches your system's GPU setup. Alternatively, you can install the dependencies directly with `pip install -r requirements.txt`, or install the package locally in editable mode with `pip install -e .`.
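To confirm that the core dependencies are importable from the active environment, a quick check like the one below may help. It is a minimal sketch and not part of the repository. The module names in ASSUMED_MODULES (torch, torchrl, lightning, hydra) are an assumption based on the stack described above, so compare them against requirements.txt.

```python
import importlib.util

# ASSUMPTION: top-level module names for the core stack; requirements.txt is authoritative.
ASSUMED_MODULES = ["torch", "torchrl", "lightning", "hydra"]


def missing_modules(names):
    """Return the module names that cannot be found on the current Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(ASSUMED_MODULES)
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        import torch  # imported only after we know it is installed

        print("All modules found; CUDA available:", torch.cuda.is_available())
```

If the light-the-torch route worked as intended on a GPU machine, the last line should report that CUDA is available.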

Train a model with the default configuration (the Attention Model, AM, on the TSP environment):

python run.py  

Train a model with an experiment configuration from configs/experiment/:

# Choose an experiment (e.g. tsp/am) and set the environment to 42 cities
python run.py experiment=tsp/am env.num_loc=42    

# Disable logging
python run.py experiment=tsp/am logger='null'

# Create a sweep over hyperparameters (-m for multirun)
python run.py -m experiment=tsp/am train.optimizer.lr=1e-3,1e-4,1e-5
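Hydra reads each `key=value` argument as an override on the composed config, and with `-m` each comma-separated value list is swept, so the command above launches one run per learning rate. The sketch below imitates that behaviour for plain values only. The helpers parse_value, expand_sweep and to_nested are hypothetical and not part of rl4co or Hydra. Real Hydra also selects config groups (as with `experiment=tsp/am`), supports a richer value grammar, and allows sweeper plugins to change the expansion strategy.

```python
import itertools


def parse_value(raw):
    """Simplified value parsing: 'null' -> None, then int, then float, else str."""
    if raw == "null":
        return None
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def expand_sweep(overrides):
    """Expand 'key=a,b,c' overrides into one list of (key, value) pairs per run.

    Runs are the Cartesian product of all value lists. Commas inside
    brackets or quotes are NOT handled in this simplified sketch.
    """
    choices = []
    for override in overrides:
        key, _, value = override.partition("=")
        choices.append([(key, parse_value(v)) for v in value.split(",")])
    return [list(combo) for combo in itertools.product(*choices)]


def to_nested(pairs):
    """Turn dotted keys such as 'env.num_loc' into nested dictionaries."""
    cfg = {}
    for key, value in pairs:
        *parents, leaf = key.split(".")
        node = cfg
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return cfg
```

For example, `expand_sweep(["experiment=tsp/am", "train.optimizer.lr=1e-3,1e-4,1e-5"])` returns three override lists, and `to_nested` turns the first one into `{"experiment": "tsp/am", "train": {"optimizer": {"lr": 0.001}}}`.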

Testing

Run tests with pytest:

pytest tests/test_*.py

We will enable automated tests when we make the repo public.

Project structure

Note: this is the general layout and may change.


├── .github                   <- GitHub Actions workflows
│
├── configs                   <- Hydra configs
│   ├── callbacks                <- Callbacks configs
│   ├── data                     <- Data configs
│   ├── debug                    <- Debugging configs
│   ├── experiment               <- Experiment configs
│   ├── extras                   <- Extra utilities configs
│   ├── hparams_search           <- Hyperparameter search configs
│   ├── hydra                    <- Hydra configs
│   ├── local                    <- Local configs
│   ├── logger                   <- Logger configs
│   ├── model                    <- Model configs
│   ├── paths                    <- Project paths configs
│   ├── trainer                  <- Trainer configs
│   │
│   ├── eval.yaml                <- Main config for evaluation
│   └── train.yaml               <- Main config for training
│
├── data                      <- Project data
│
├── logs                      <- Logs generated by Hydra and Lightning loggers
│
├── notebooks                 <- Jupyter notebooks. Naming convention is a number (for ordering),
│                                the creator's initials, and a short `-` delimited description,
│                                e.g. `1.0-jqp-initial-data-exploration.ipynb`.
│
├── scripts                   <- Shell scripts
│
├── rl4co                     <- Source code # NOTE: WIP
│   ├── envs                     <- RL environments
│   ├── models                   <- Model scripts
│   ├── rl                       <- RL algorithms
│   ├── utils                    <- Utility scripts
│   │
│   ├── eval.py                  <- Run evaluation
│   └── train.py                 <- Run training
│
├── tests                     <- Tests of any kind
│
├── .env.example              <- Example of file for storing private environment variables
├── .gitignore                <- List of files ignored by git
├── .pre-commit-config.yaml   <- Configuration of pre-commit hooks for code formatting
├── .project-root             <- File for inferring the position of project root directory
├── Makefile                  <- Makefile with commands like `make train` or `make test`
├── pyproject.toml            <- Configuration options for testing and linting
├── requirements.txt          <- File for installing python dependencies
├── setup.py                  <- File for installing project as a package
└── README.md
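The `.project-root` file lets scripts locate the repository root no matter which directory they are launched from. The Lightning-Hydra-Template delegates this to a small helper package; the standard-library sketch below only illustrates the idea of walking up parent directories until the marker file appears. find_project_root is a hypothetical name, not an rl4co function.

```python
from pathlib import Path


def find_project_root(start=None, indicator=".project-root"):
    """Walk up from `start` (default: current directory) until `indicator` is found."""
    current = (Path(start) if start is not None else Path.cwd()).resolve()
    for directory in (current, *current.parents):
        if (directory / indicator).exists():
            return directory
    raise FileNotFoundError(f"No {indicator!r} found above {current}")
```

Called from anywhere inside the repository, for instance from rl4co/envs, it should return the directory that contains `.project-root`.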