# dvc.yaml - DVC pipeline definition.
# (Web-page navigation text and the line-number gutter left over from the
# page extraction were removed here; they were not part of the file.)
# DVC pipeline: prepare -> features -> split -> stats -> fit -> evaluate.
# Each stage lists the command it runs (cmd), the files whose changes
# invalidate it (deps), the tracked parameters (params) and what it produces
# (outs / plots / metrics).
stages:
  # Runs scripts/prepare.py on the raw data directory and writes
  # data/prepared/data.csv.
  prepare:
    cmd: python scripts/prepare.py data/raw
    deps:
      - data/raw/File1.txt
      - data/raw/File2.txt
      - data/raw/File3.txt
      - data/raw/File4.txt
      - data/raw/File5.txt
      - data/raw/File6.txt
      # File name contains spaces; it is a single path entry.
      - data/raw/SME and Residential allocations.txt
      - scripts/prepare.py
    params:
      # Parameters are read from scripts/params.yaml instead of the default
      # params.yaml, hence the nested "file: [keys]" form.
      - scripts/params.yaml:
          - prepare.seed
          - prepare.subset
    outs:
      - data/prepared/data.csv

  # Runs scripts/features.py on the prepared data and writes
  # data/features/data.csv.
  features:
    cmd: python scripts/features.py data/prepared/data.csv
    deps:
      - data/prepared/data.csv
      - scripts/features.py
    params:
      - scripts/params.yaml:
          - features.generate
    outs:
      - data/features/data.csv

  # Runs scripts/split.py on the feature data and writes the train/test CSVs.
  split:
    cmd: python scripts/split.py data/features/data.csv
    deps:
      - data/features/data.csv
      - scripts/split.py
    params:
      - scripts/params.yaml:
          - split.test_size
    outs:
      - data/split/test.csv
      - data/split/train.csv

  # Runs scripts/validate_data.py on both splits and writes the files under
  # data/stats/.
  stats:
    cmd: python scripts/validate_data.py data/split/test.csv data/split/train.csv
    deps:
      - data/split/test.csv
      - data/split/train.csv
      - scripts/validate_data.py
    outs:
      - data/stats/test.csv
      - data/stats/test_max_load
      - data/stats/train.csv
      - data/stats/train_max_load

  # One stage instance per entry in `foreach`; ${item} is substituted with the
  # entry name, selecting configs/<item>.yaml and the logs/<item>/ outputs.
  fit:
    foreach:
      - feed_forward_bernstein_flow
      - feed_forward_gaussian_mixture_model
      - feed_forward_normal_distribution
      - feed_forward_quantile_regression
      - wavenet_bernstein_flow
      - wavenet_gaussian_mixture_model
      - wavenet_normal_distribution
      - wavenet_quantile_regression
    do:
      cmd: python scripts/train.py configs/${item}.yaml configs/shared/data_loader_kwds.yaml logs
      deps:
        - conda_env.yaml
        - setup.py
        - configs/${item}.yaml
        - configs/shared
        - data/split/train.csv
        - data/stats/train.csv
        - data/stats/train_max_load
        - src
        - scripts/train.py
      params:
        - scripts/params.yaml:
            - test_mode
      plots:
        - logs/${item}/log.csv
      outs:
        - logs/${item}/mcp

  # One stage instance per entry in `foreach`. The list is the same as in
  # `fit` plus `baseline`, for which no `fit` instance is defined in this file.
  evaluate:
    foreach:
      - feed_forward_bernstein_flow
      - feed_forward_gaussian_mixture_model
      - feed_forward_normal_distribution
      - feed_forward_quantile_regression
      - wavenet_bernstein_flow
      - wavenet_gaussian_mixture_model
      - wavenet_normal_distribution
      - wavenet_quantile_regression
      - baseline
    do:
      cmd: python scripts/evaluate.py configs/${item}.yaml configs/shared/data_loader_kwds.yaml logs metrics/${item}.yaml
      deps:
        - conda_env.yaml
        - setup.py
        - configs/${item}.yaml
        - configs/shared
        - data/split/test.csv
        - data/stats/train.csv
        - data/stats/train_max_load
        # Directory dependency: covers everything under logs/<item>/.
        - logs/${item}
        - scripts/evaluate.py
        - src
      params:
        - scripts/params.yaml:
            - test_mode
      metrics:
        - metrics/${item}.yaml