# run.py (forked from ai4co/rl4co)
from typing import List, Optional, Sequence, Tuple

import hydra
import lightning as L
import pyrootutils
import torch
from lightning import Callback, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

# Register the project root (found via the .gitignore indicator) on sys.path
# so that the `rl4co` import below resolves regardless of the working directory
pyrootutils.setup_root(__file__, indicator=".gitignore", pythonpath=True)

from rl4co import utils

log = utils.get_pylogger(__name__)


@utils.task_wrapper
def run(cfg: DictConfig) -> Tuple[dict, dict]:
    """Train the model and optionally evaluate it on a test set, using the best
    weights obtained during training.

    This method is wrapped in an optional ``@task_wrapper`` decorator that controls
    the behavior on failure. Useful for multiruns, saving info about the crash, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """
    # Set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    # Note that the RL environment is instantiated inside the model
    log.info(f"Instantiating task <{cfg.task._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.task, cfg, _recursive_=False)
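    # Note: `_recursive_=False` keeps Hydra from instantiating nested configs;
    # the task is expected to build its own sub-components (including the RL
    # environment, as noted above) from the raw config.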
if cfg.get("transfer"):
from rl4co.utils.lightning import load_model_from_checkpoint
from rl4co.utils.transfer import transplant_weights
log.info(f"load pretrained model")
device = model.device
pretrained_model = load_model_from_checkpoint(
cfg.transfer.source.config, cfg.transfer.source.checkpoint_path, device=device
)
transplant_weights(pretrained_model, model, **cfg.transfer.transfer_config)
del pretrained_model
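    # A minimal sketch of the `transfer` config section implied by the attribute
    # accesses above; the key names follow the code, but the concrete values are
    # hypothetical:
    #
    #   transfer:
    #     source:
    #       config: <path to the source model's config>        # hypothetical
    #       checkpoint_path: <path to the source .ckpt file>   # hypothetical
    #     transfer_config: {}  # kwargs forwarded to transplant_weights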
log.info("Instantiating callbacks...")
callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
log.info("Instantiating loggers...")
logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
# Configure DDP automatically
n_devices = cfg.trainer.get("devices", 1)
if isinstance(n_devices, Sequence):
n_devices = len(n_devices)
if n_devices > 1 and cfg.trainer.get("strategy", None) is None:
log.info("Configuring DDP strategy automatically")
cfg.trainer.strategy = dict(
_target_="lightning.pytorch.strategies.DDPStrategy",
find_unused_parameters=True, # We set to True due to RL envs
gradient_as_bucket_view=True, # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations
)
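    # The automatic setup above only kicks in when no strategy is configured; a
    # strategy can still be set explicitly, e.g. via a Hydra CLI override
    # (standard Hydra syntax; a sketch, not a verified config for this repo):
    #
    #   python run.py trainer.devices=2 trainer.strategy=ddp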
    # Set matmul precision for faster inference
    # https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
    torch.set_float32_matmul_precision(cfg.get("matmul_precision", "medium"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    if cfg.trainer.get("reload_dataloaders_every_n_epochs", 1) != 1:
        log.warning(
            "We must reload dataloaders every epoch for RL training. "
            "Ignoring the reload_dataloaders_every_n_epochs key in the trainer config."
        )
    # Always force a reload; this value overrides whatever is in the config
    reload_dataloaders_every_n_epochs = 1
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=logger,
        reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs,
    )
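    # Note: keyword arguments passed to hydra.utils.instantiate take precedence
    # over the matching keys in cfg.trainer, so the callbacks, loggers, and the
    # forced dataloader reload above override the config file values.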
    object_dict = {
        "cfg": cfg,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("compile", False):
        log.info("Compiling model!")
        model = torch.compile(model)  # requires PyTorch >= 2.0
if cfg.get("train"):
log.info("Starting training!")
trainer.fit(model=model, ckpt_path=cfg.get("ckpt_path"))
train_metrics = trainer.callback_metrics
if cfg.get("test"):
log.info("Starting testing!")
ckpt_path = trainer.checkpoint_callback.best_model_path
if ckpt_path == "":
log.warning("Best ckpt not found! Using current weights for testing...")
ckpt_path = None
trainer.test(model=model, ckpt_path=ckpt_path)
log.info(f"Best ckpt path: {ckpt_path}")
test_metrics = trainer.callback_metrics
# merge train and test metrics
metric_dict = {**train_metrics, **test_metrics}
return metric_dict, object_dict


@hydra.main(version_base="1.3", config_path="configs", config_name="main.yaml")
# @hydra.main(version_base="1.3", config_path="configs", config_name="experiment/tsp/am-ppo.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    # Apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    utils.extras(cfg)

    # Train the model
    metric_dict, _ = run(cfg)

    # Safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = utils.get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )

    # Return the optimized metric
    return metric_value


if __name__ == "__main__":
    main()
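
# Example invocations (a sketch: the experiment name is inferred from the
# commented-out config path above, and the exact override keys are assumptions):
#
#   python run.py                          # train with configs/main.yaml
#   python run.py experiment=tsp/am-ppo    # select an experiment config
#   python run.py seed=1234 trainer.devices=2 test=True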