# @package _global_
defaults:
  - _self_
  - experiment: base # Specifies model and pipeline, equivalent to next two lines
  # - model: s4 # Model backbone
  # - pipeline: cifar # Specifies collection of configs, equivalent to next 5 lines
  #   Pipelines should specify /loader, /dataset, /task, /encoder, /decoder (ideally in that order)
  # #   - loader: default # Dataloader (e.g. handles batches)
  # #   - dataset: cifar # Defines the data (x and y pairs)
  # #   - task: multiclass_classification # Defines loss and metrics
  # #   - encoder: null # Interface between data and model
  # #   - decoder: null # Interface between model and targets
  - callbacks: # Extra pytorch-lightning features
      - base
      - checkpoint
      - rich # RichProgressBar and RichModelSummary
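# Minimal sketch of swapping the composed groups from the command line, assuming
# the repo's usual `python -m train` entry point; the experiment name is a
# placeholder for whatever exists under configs/experiment/:
#   python -m train experiment=<experiment-name> train.seed=1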
# Additional arguments used to configure the training loop
# Most of these set combinations of options in the PL trainer, add callbacks, or add features to the optimizer
train:
  seed: 0
  name: null # Optional name for the run to make logging easier
  # These three options are used by callbacks (checkpoint, monitor) and scheduler
  # Most of them are task dependent and are set by the pipeline
  interval: ??? # Should be specified by scheduler. Also used by LR monitor
  monitor: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer
  mode: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer
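  # Illustrative values only: these three keys are normally injected by the
  # pipeline/scheduler configs rather than set here. A classification pipeline
  # might supply something like:
  #   interval: epoch
  #   monitor: val/accuracy
  #   mode: max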
  ema: 0.0 # Moving average model for validation # TODO move into callback
  test: False # Test after training
  debug: False # Special settings to make debugging more convenient
  ignore_warnings: False # Disable Python warnings
  # These control state passing between batches
  state:
    mode: null # [ None | 'none' | 'reset' | 'bptt' | 'tbptt' ]
    n_context: 0 # How many steps to use as memory context. Must be >= 0, or None (null) for infinite context
    n_context_eval: ${.n_context} # Context at evaluation time
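  # Sketch of a possible stateful-training override (values are illustrative,
  # not defaults; the allowed modes are listed above):
  #   state:
  #     mode: tbptt
  #     n_context: 1 # carry state from one previous batch at train time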
  # Convenience keys to allow grouping runs
  ckpt: null # Resume training
  optimizer_param_grouping:
    bias_weight_decay: False
    normalization_weight_decay: False
  disable_dataset: False # Disable dataset loading
  validate_at_start: False
  pretrained_model_path: null # Path to pretrained model
  pretrained_model_strict_load: true # If true, require the checkpoint's state_dict to exactly match the model (strict load)
  pretrained_model_state_hook: # Hook called on the loaded model's state_dict
    _name_: null
  post_init_hook: # After initializing model, call method on model
    _name_: null
  layer_decay: # Used for ImageNet finetuning
    _name_: null
    decay: 0.7
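  # Hypothetical finetuning sketch (the path, the strict-load choice, and the
  # hook name are placeholders, not defaults from this repo):
  #   pretrained_model_path: /path/to/pretrained.ckpt
  #   pretrained_model_strict_load: false # tolerate missing/extra keys, e.g. a new decoder
  #   layer_decay:
  #     _name_: <hook-name-from-this-repo's-registry>
  #     decay: 0.7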
  # PL 2.0 seems to have gotten rid of the trainer.track_grad_norm flag
  # We have a custom Callback (TrackNorms) that implements something similar
  track_grad_norms: False

tolerance: # Fault tolerance for training on preemptible machines
  logdir: ./resume
  id: null # Must be set to resume training on preemption
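# Sketch of resuming after preemption (the id is a placeholder for whatever was
# assigned to the interrupted run; exact usage depends on this fork's tolerance hooks):
#   python -m train ... tolerance.id=<previous-run-id>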
# We primarily use wandb so this is moved to top level in the config for convenience
# Set `~wandb` or `wandb=null` or `wandb.mode=disabled` to disable logging
# If other loggers are added, it would make sense to put this one level lower under train/ or logger/
wandb:
  project: hippo
  group: ""
  job_type: training
  mode: online # choices=['online', 'offline', 'disabled']
  save_dir: "."
  id: null # Pass the correct id to resume an experiment!
  # Below options should not need to be specified
  # entity: "" # set to name of your wandb team or just remove it
  # log_model: False
  # prefix: ""
  # job_type: "train"
  # tags: []
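# Common logging overrides, restating the options noted above as a sketch
# (assumes the `python -m train` entry point):
#   python -m train wandb=null          # disable wandb entirely
#   python -m train wandb.mode=offline  # log locally, sync later
#   python -m train wandb.id=<run-id>   # resume logging to an existing run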
hydra:
  run:
    dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S-%f}
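# The run directory can also be overridden with standard Hydra syntax, e.g.
# `python -m train hydra.run.dir=./outputs/my-run` (illustrative path).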