forked from Lightning-AI/litgpt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_adapter.py
106 lines (82 loc) · 3.55 KB
/
test_adapter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from contextlib import redirect_stdout
from dataclasses import asdict
from io import StringIO
from unittest.mock import Mock
import torch
from lightning import Fabric
def test_config_identical():
    """Adapter and base configs/models must agree apart from the adapter-only fields.

    The adapter ``Config`` is expected to equal the base ``Config`` once
    ``adapter_prompt_length`` and ``adapter_start_layer`` are removed, and the
    ``lm_head`` weight of both models must have the same shape.
    """
    import lit_gpt.adapter as gpt_adapter
    import lit_gpt.model as gpt

    model_name = "pythia-70m"
    reference_cfg = asdict(gpt.Config.from_name(model_name))
    adapter_cfg = asdict(gpt_adapter.Config.from_name(model_name))

    # Strip the fields that only exist on the adapter config; ``pop`` without a
    # default still fails loudly (KeyError) if one of them is missing.
    for adapter_only_key in ("adapter_prompt_length", "adapter_start_layer"):
        adapter_cfg.pop(adapter_only_key)
    assert adapter_cfg == reference_cfg

    # Instantiate both models on CPU with empty (uninitialized) weights; only
    # the parameter shapes are inspected below.
    cpu_fabric = Fabric(accelerator="cpu")
    with cpu_fabric.init_module(empty_init=True):
        reference_model = gpt.GPT.from_name(model_name)
        adapter_model = gpt_adapter.GPT.from_name(model_name)

    adapter_head_shape = adapter_model.lm_head.weight.shape
    assert adapter_head_shape == reference_model.lm_head.weight.shape
def test_adapter_filter(tmp_path):
    """Saving with ``adapter_filter`` must keep exactly the expected adapter entries.

    A 4-layer adapter model is saved through ``Fabric.save`` with
    ``adapter_filter`` applied to the ``"model"`` entry; the reloaded state is
    then compared, by key, against the adapter parameters of layers 2 and 3.
    """
    from lit_gpt.adapter import GPT, adapter_filter

    fabric = Fabric(devices=1)
    adapter_gpt = GPT.from_name("pythia-70m", n_layer=4)

    checkpoint_file = tmp_path / "model.pth"
    state_to_save = {"model": adapter_gpt}
    fabric.save(checkpoint_file, state_to_save, filter={"model": adapter_filter})

    reloaded = torch.load(checkpoint_file)
    saved_model_state = reloaded["model"]

    expected_keys = {
        "transformer.h.2.attn.adapter_wte.weight",
        "transformer.h.2.attn.gating_factor",
        "transformer.h.3.attn.adapter_wte.weight",
        "transformer.h.3.attn.gating_factor",
    }
    assert set(saved_model_state) == expected_keys
def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch):
    """Run the adapter finetuning script end to end on a tiny, mocked setup.

    The script's module-level settings are shrunk so the run takes only a few
    iterations, checkpoint loading and the tokenizer are replaced by mocks, and
    the test then checks the checkpoint files written, the metrics file and the
    captured stdout.

    Args:
        tmp_path: pytest temporary directory, used as both data and output dir.
        fake_checkpoint_dir: fixture defined outside this file; passed through
            unchanged as ``checkpoint_dir``.
        monkeypatch: pytest fixture used for every override, so all patched
            state is restored when the test finishes.
    """
    import finetune.adapter as module
    from lit_gpt.config import name_to_config

    # Override the script settings through ``monkeypatch`` instead of plain
    # attribute assignment: plain assignment would leave the shrunken values on
    # the imported module and leak them into any later test using it.
    # ``raising=False`` keeps plain-assignment semantics should an attribute
    # not already exist on the module.
    script_overrides = {
        "gradient_accumulation_iters": 1,
        "save_interval": 2,
        "eval_interval": 2,
        "eval_iters": 2,
        "max_iters": 6,
    }
    for setting_name, setting_value in script_overrides.items():
        monkeypatch.setattr(module, setting_name, setting_value, raising=False)

    # Minimal dataset: the same two samples serve as train and test split.
    data = [
        {"input_ids": torch.tensor([0, 1, 2]), "labels": torch.tensor([1, 2, 3])},
        {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([2, 3, 4])},
    ]
    torch.save(data, tmp_path / "train.pt")
    torch.save(data, tmp_path / "test.pt")

    # Register a tiny model configuration under the name "tmp" for the
    # duration of the test.
    model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0)
    monkeypatch.setitem(name_to_config, "tmp", model_config)

    # Stand-in for ``lazy_load``: calling it returns the mock itself, and the
    # mock can be used as a context manager.
    load_mock = Mock()
    load_mock.return_value = load_mock
    load_mock.__enter__ = Mock()
    load_mock.__exit__ = Mock()
    monkeypatch.setattr(module, "lazy_load", load_mock)

    # Stand-in for ``Tokenizer``: constructing it returns the mock, whose
    # ``encode`` always yields the same three token ids.
    tokenizer_mock = Mock()
    tokenizer_mock.return_value = tokenizer_mock
    tokenizer_mock.encode = lambda *_, **kwargs: torch.tensor([3, 2, 1], **kwargs)
    monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)

    # Capture everything the script prints so the log can be inspected below.
    stdout = StringIO()
    with redirect_stdout(stdout):
        module.setup(data_dir=tmp_path, checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path, precision="32-true")

    assert {p.name for p in tmp_path.glob("*.pth")} == {
        "iter-000001-ckpt.pth",
        "iter-000003-ckpt.pth",
        "iter-000005-ckpt.pth",
        "lit_model_adapter_finetuned.pth",
    }
    assert (tmp_path / "version_0" / "metrics.csv").is_file()

    logs = stdout.getvalue()
    assert logs.count("optimizer.step") == module.max_iters
    assert logs.count("val loss") == module.max_iters // module.eval_interval
    assert "of trainable parameters: 168" in logs
def test_adapter_gpt_init_weights():
    """``GPT._init_weights`` must reset the adapter gating factor to zero.

    The gating factor of the first block is checked to be all zeros on a fresh
    model, overwritten with a non-zero constant, and checked to be all zeros
    again after ``_init_weights`` is re-applied to the model.
    """
    from lit_gpt.adapter import GPT, Config

    tiny_config = Config(n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1, adapter_start_layer=0)
    adapter_gpt = GPT(tiny_config)
    gating = adapter_gpt.transformer.h[0].attn.gating_factor

    # Freshly constructed model: gating factor is entirely zero.
    assert torch.all(gating == 0)

    # Perturb the parameter in place and confirm the change is visible.
    torch.nn.init.constant_(gating, 1.23)
    assert torch.any(gating != 0)

    # Re-running the initializer must bring the gating factor back to zero.
    adapter_gpt.apply(adapter_gpt._init_weights)
    assert torch.all(gating == 0)