# forked from Lightning-AI/litgpt
# test_serve.py — 42 lines (34 loc) · 1.6 KB
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import asdict
import shutil
from lightning.fabric import seed_everything
from fastapi.testclient import TestClient
from litserve.server import LitServer
import torch
import yaml
from litgpt import GPT, Config
from litgpt.deploy.serve import SimpleLitAPI
from litgpt.scripts.download import download_from_hub
def test_simple(tmp_path):
    """Smoke-test the deployment server end to end.

    Builds a tiny random pythia-14m checkpoint (weights + tokenizer +
    config) in ``tmp_path``, serves it with LitServer on CPU, and checks
    that a single ``/predict`` request returns the expected seeded output.
    """
    # Assemble a complete checkpoint directory from a seeded random model.
    seed_everything(123)
    config = Config.from_name("pythia-14m")
    download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
    hub_dir = tmp_path / "EleutherAI" / "pythia-14m"
    for fname in ("tokenizer.json", "tokenizer_config.json"):
        shutil.move(str(hub_dir / fname), str(tmp_path))
    model = GPT(config)
    torch.save(model.state_dict(), tmp_path / "lit_model.pth")
    with open(tmp_path / "model_config.yaml", "w", encoding="utf-8") as stream:
        yaml.dump(asdict(config), stream)

    # Serve the checkpoint on CPU and issue one prediction request.
    server = LitServer(
        SimpleLitAPI(checkpoint_dir=tmp_path, temperature=1, top_k=1),
        accelerator="cpu",
        devices=1,
        timeout=60,
    )
    with TestClient(server.app) as client:
        response = client.post("/predict", json={"prompt": "Hello world"})
        # Model is a small random model, not trained, hence the gibberish.
        # We are just testing that the server works.
        assert response.json()["output"][:19] == " statues CAD pierci"