# config.py -- forked from Om-Alve/smolGPT
from dataclasses import dataclass
import torch
@dataclass
class GPTConfig:
    """Hyperparameters describing the GPT model architecture.

    Every field has a default, so ``GPTConfig()`` yields a complete
    configuration and any value can be overridden by keyword (or
    positionally, in declaration order).

    NOTE(review): the per-field notes below are inferred from the field
    names only; this file does not show how the model consumes them, so
    confirm against the model code.
    """

    block_size: int = 512  # presumably the maximum context length, in tokens
    vocab_size: int = 4096  # presumably the tokenizer vocabulary size
    n_layer: int = 8  # presumably the number of transformer blocks
    n_head: int = 8  # presumably the number of attention heads per block
    n_embed: int = 512  # presumably the embedding / hidden width
    dropout: float = 0.2  # presumably a dropout probability
    bias: bool = False  # presumably toggles bias terms in the layers
@dataclass
class TrainingConfig:
    """Hyperparameters and runtime settings for a training run.

    Every field has a default, so ``TrainingConfig()`` yields a complete
    configuration and any value can be overridden by keyword (or
    positionally, in declaration order).

    NOTE(review): unless a comment says otherwise, the per-field notes are
    inferred from the field names only; the training loop that reads these
    values is not visible here, so confirm against it.
    """

    # --- optimizer (presumably AdamW-style; confirm against the trainer) ---
    learning_rate: float = 6e-4  # presumably the peak learning rate
    max_iters: int = 30000  # presumably the total number of training iterations
    weight_decay: float = 1e-1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0  # presumably a gradient-norm clipping threshold

    # --- learning-rate schedule ---
    decay_lr: bool = True  # presumably enables the warmup/decay schedule
    warmup_iters: int = 1000
    lr_decay_iters: int = 30000  # default equals max_iters
    min_lr: float = 6e-5  # default is one tenth of learning_rate

    # --- evaluation / logging cadence (presumably in iterations) ---
    eval_interval: int = 100
    log_interval: int = 10
    eval_iters: int = 200

    # --- batching ---
    gradient_accumulation_steps: int = 4
    batch_size: int = 64

    # --- runtime ---
    # This default is computed once, when the class body is executed (i.e. at
    # import time), not per instance: "cuda" if torch reports CUDA as
    # available, otherwise "cpu".
    device: str = str(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    # NOTE(review): the default stays "bfloat16" even when ``device`` resolves
    # to "cpu"; confirm the trainer handles that combination.
    dtype: str = "bfloat16"
    compile: bool = True  # presumably toggles torch.compile on the model