-
Notifications
You must be signed in to change notification settings - Fork 0
/
_main_pytorch.py
59 lines (43 loc) · 1.6 KB
/
_main_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
from sequel.utils.loggers.logging import install_logging
from sequel.utils.callbacks.metrics.pytorch_metric_callback import StandardMetricCallback
from sequel.benchmarks import select_benchmark
from sequel.backbones.pytorch import select_backbone, select_optimizer
from sequel.utils.callbacks.tqdm_callback import TqdmCallback
from sequel.utils.loggers.wandb_logger import WandbLogger
from sequel.algos.pytorch import ALGOS
from sequel.utils.utils import set_seed
def without(d, key):
    """Return a shallow copy of *d* with *key* removed.

    The input mapping is never mutated. Unlike ``dict.pop`` without a
    default, a missing *key* is silently ignored instead of raising
    ``KeyError`` — callers only use this to strip an expected key, but
    tolerating its absence makes the helper safe to reuse.

    Args:
        d: Mapping to copy (any mapping, not just ``dict``).
        key: Key to drop from the copy.

    Returns:
        A new ``dict`` equal to ``d`` minus ``key``.
    """
    new_d = dict(d)
    # Default of None makes the removal a no-op when the key is absent.
    new_d.pop(key, None)
    return new_d
@hydra.main(config_path="configs", config_name="config", version_base="1.1")
def my_app(config: DictConfig) -> None:
    """Hydra entry point: wire up and run a continual-learning experiment.

    Builds the benchmark, backbone network, optimizer and algorithm from
    the resolved hydra config, then launches training over all tasks.
    """
    install_logging()
    logging.info("The experiment config is:\n" + OmegaConf.to_yaml(config))

    wandb_logger = WandbLogger(config)
    set_seed(config.seed)

    # Callbacks: metric tracking plus a tqdm progress bar.
    metric_callback = StandardMetricCallback()
    progress_callback = TqdmCallback()

    # Benchmark (e.g. SplitMNIST) defines the sequence of tasks.
    benchmark = select_benchmark(config.benchmark)
    logging.info(benchmark)

    # Backbone network (e.g. a CNN or MLP) and its optimizer.
    backbone = select_backbone(config)
    logging.info(backbone)
    optimizer = select_optimizer(config, backbone)

    # Look up the algorithm class by name; every other entry of
    # config.algo is forwarded to its constructor as a keyword argument.
    algo_cls = ALGOS[config.algo.name.lower()]
    algo = algo_cls(
        **without(dict(config.algo), "name"),
        backbone=backbone,
        benchmark=benchmark,
        optimizer=optimizer,
        callbacks=[metric_callback, progress_callback],
        loggers=[wandb_logger],
    )
    logging.info(algo)

    # Kick off the learning process across all tasks.
    algo.fit(epochs=config.training.epochs_per_task)
# Run the hydra-managed entry point only when executed as a script;
# hydra parses CLI overrides and supplies the config to my_app.
if __name__ == "__main__":
    my_app()