-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathclsft.yaml
42 lines (36 loc) · 1.09 KB
/
clsft.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# @package _global_
# Hydra reads the package directive above from the file header, so it stays
# ahead of every other line; these notes intentionally come after it.
#
# Experiment overrides for a fine-tuning run (`action: finetune`): selects the
# `smart` model config and overrides learning-rate, token-sampling,
# rollout-sampling, loss, trainer and dataloader settings on top of it.
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation had been stripped; confirm that `finetune` belongs under
# `model.model_config` by checking the base `model/smart` config.

defaults:
  # - override /trainer: ddp
  - override /model: smart

model:
  model_config:
    # Floats carry an explicit mantissa dot (5.0e-5 rather than 5e-5) so that
    # YAML 1.1 loaders also read them as numbers instead of strings.
    lr: 5.0e-5
    lr_min_ratio: 0.05
    token_processor:
      map_token_sampling:  # open-loop
        num_k: 1  # for k nearest neighbors
        temp: 1.0  # uniform sampling
      agent_token_sampling:  # closed-loop
        num_k: 1  # for k nearest neighbors
        temp: 1.0
    training_rollout_sampling:
      # one of {topk_dist_sampled_with_prob, topk_prob,
      # topk_prob_sampled_with_dist}
      criterium: topk_prob_sampled_with_dist
      # for k nearest neighbors, set to -1 to turn-off closed-loop training
      num_k: 32
      temp: 1.0e-5  # catk = topk_prob_sampled_with_dist with temp=1e-5
    training_loss:
      use_gt_raw: true
      gt_thresh_scale_length: -1  # {"veh": 4.8, "cyc": 2.0, "ped": 1.0}
      label_smoothing: 0.0
      rollout_as_gt: false
    finetune: true

# NOTE(review): looks like a placeholder file name; presumably it must be
# replaced with the path to an existing pretrained checkpoint before running.
ckpt_path: BC_PRETRAINED_MODEL.ckpt

trainer:
  limit_train_batches: 1.0
  limit_val_batches: 50
  check_val_every_n_epoch: 1

data:
  train_batch_size: 10
  val_batch_size: 10
  test_batch_size: 10
  num_workers: 10

action: finetune