Commit
Feat/inference extra param (#1167)
sxwl-donggang authored Nov 20, 2024
1 parent 5b7d093 commit c935183
Showing 14 changed files with 379 additions and 182 deletions.
237 changes: 144 additions & 93 deletions .github/scripts/inference.py
@@ -3,12 +3,51 @@
import requests
import os
import atexit
import sys
import signal
from typing import Optional, Dict, Any, List
from dataclasses import dataclass

# Reconfigure stdout to line buffering so print output appears immediately
sys.stdout.reconfigure(line_buffering=True)

# Read the base URL from the environment, falling back to the default
BASE_URL = os.getenv('SXWL_API_URL', 'https://llm.nascentcore.net')

# Global registry of resources that still need to be cleaned up
resources_to_cleanup = {
    'inference_services': set(),
    'finetune_jobs': set()
}

def cleanup_handler(signum, frame):
    """Handle interrupt signals by cleaning up all registered resources."""
    print("Received interrupt signal, starting resource cleanup...", flush=True)
    config = APIConfig.create_default()
    client = APIClient(config)

    # Clean up inference services (iterate over a copy; delete_* discards from the set)
    for service_name in list(resources_to_cleanup['inference_services']):
        try:
            client.delete_inference_service(service_name)
        except Exception as e:
            print(f"Failed to clean up inference service {service_name}: {e}", flush=True)

    # Clean up finetune jobs (iterate over a copy; delete_* discards from the set)
    for job_id in list(resources_to_cleanup['finetune_jobs']):
        try:
            client.delete_finetune_job(job_id)
        except Exception as e:
            print(f"Failed to clean up finetune job {job_id}: {e}", flush=True)

    # Exit only when actually invoked as a signal handler; main() also calls
    # this with signum=None and must be allowed to re-raise its error.
    if signum is not None:
        sys.exit(0)

# Register signal handlers
signal.signal(signal.SIGINT, cleanup_handler)
signal.signal(signal.SIGTERM, cleanup_handler)
signal.signal(signal.SIGQUIT, cleanup_handler)
signal.signal(signal.SIGHUP, cleanup_handler)
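# Note: the CI workflow below runs this script under
# `timeout --preserve-status --signal=TERM 30m`, so the SIGTERM handler above
# also performs cleanup when a CI run hits its 30-minute limit, and
# `--preserve-status` then surfaces the script's exit status to the CI step.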

@dataclass
class APIConfig:
    base_url: str
@@ -18,6 +57,9 @@ class APIConfig:
    @classmethod
    def create_default(cls) -> 'APIConfig':
        token = os.environ["SXWL_TOKEN"]
        if token and isinstance(token, bytes):
            # If the token arrives as bytes, decode it to a string
            token = token.decode('utf-8')
        headers = {
            'Accept': 'application/json, text/plain, */*',
            'Authorization': f'Bearer {token}',
@@ -44,25 +86,27 @@ def get_models(self) -> List[Dict[str, Any]]:
            response = self._make_request('GET', '/resource/models')
            data = response.json()
            models = data.get('public_list', []) + data.get('user_list', [])
            print(f"Fetched {len(models)} models", flush=True)
            return models
        except Exception as e:
            print(f"Failed to fetch model list: {str(e)}", flush=True)
            return []

    def delete_inference_service(self, service_name: str) -> None:
        try:
            self._make_request('DELETE', '/job/inference', params={'service_name': service_name})
            print("Inference service deleted successfully", flush=True)
            resources_to_cleanup['inference_services'].discard(service_name)
        except Exception as e:
            print(f"Failed to delete inference service: {str(e)}", flush=True)

    def delete_finetune_job(self, finetune_id: str) -> None:
        try:
            self._make_request('POST', '/userJob/job_del', json={'job_id': finetune_id})
            print("Finetune job deleted successfully", flush=True)
            resources_to_cleanup['finetune_jobs'].discard(finetune_id)
        except Exception as e:
            print(f"Failed to delete finetune job: {str(e)}", flush=True)

class InferenceService:
    def __init__(self, client: APIClient):
@@ -73,8 +117,9 @@ def __init__(self, client: APIClient):
    def deploy(self, model_config: Dict[str, Any]) -> None:
        response = self.client._make_request('POST', '/job/inference', json=model_config)
        self.service_name = response.json()['service_name']
        print(f"Service name: {self.service_name}", flush=True)
        # Register the service for cleanup
        resources_to_cleanup['inference_services'].add(self.service_name)
        self._wait_for_ready()

    def _wait_for_ready(self, max_retries: int = 60, retry_interval: int = 30) -> None:
@@ -86,11 +131,11 @@ def _wait_for_ready(self, max_retries: int = 60, retry_interval: int = 30) -> None:
                if item['service_name'] == self.service_name:
                    if item['status'] == 'running':
                        self.api_endpoint = item['api']
                        print(f"Service is ready: {item}", flush=True)
                        return
                    break

            print(f"Service starting... ({attempt + 1}/{max_retries})", flush=True)
            time.sleep(retry_interval)

        raise TimeoutError("Timed out waiting for the service to start")
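        # (Defaults allow 60 retries × 30 s, about 30 minutes, the same budget
        # as the CI-level timeout in ci.yaml.)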
@@ -116,22 +161,23 @@ def __init__(self, client: APIClient):
    def start(self, finetune_config: Dict[str, Any]) -> None:
        response = self.client._make_request('POST', '/job/finetune', json=finetune_config)
        self.job_id = response.json()['job_id']
        print(f"Finetune job ID: {self.job_id}", flush=True)
        # Register the job for cleanup
        resources_to_cleanup['finetune_jobs'].add(self.job_id)
        self._wait_for_completion()
        self._get_adapter_id()

    def _wait_for_completion(self, max_retries: int = 60, retry_interval: int = 30) -> None:
        for attempt in range(max_retries):
            print(f"Checking finetune job status... (attempt {attempt + 1}/{max_retries})", flush=True)
            response = self.client._make_request('GET', '/job/training',
                                                 params={'current': 1, 'size': 1000})

            print(f"API response: {response.json()}", flush=True)
            for job in response.json().get('content', []):
                if job['jobName'] == self.job_id:
                    status = job['status']
                    print(f"Finetune status: {status}", flush=True)

                    if status == 'succeeded':
                        return
@@ -150,91 +196,96 @@ def _get_adapter_id(self) -> None:
                meta = json.loads(adapter.get('meta', '{}'))
                if meta.get('finetune_id') == self.job_id:
                    self.adapter_id = adapter['id']
                    print(f"Adapter ID: {self.adapter_id}", flush=True)
                    return
            except json.JSONDecodeError:
                continue

        raise ValueError("No matching adapter found")

def main():
    try:
        # Initialize the API client
        config = APIConfig.create_default()
        client = APIClient(config)

        # Fetch the list of available models
        available_models = client.get_models()
        # Check that the target model is in the available list
        target_model_id = "model-storage-0ce92f029254ff34"
        model_found = False
        for model in available_models:
            if model.get('id') == target_model_id:
                model_found = True
                print(f"Found target model: {model.get('name')}", flush=True)
                break

        if not model_found:
            raise ValueError(f"No model found with ID {target_model_id}")

        # Deploy the base-model inference service
        base_model = InferenceService(client)
        base_model_config = {
            "gpu_model": "NVIDIA-GeForce-RTX-3090",
            "model_category": "chat",
            "gpu_count": 1,
            "model_id": model.get('id'),
            "model_name": model.get('name'),
            "model_size": 15065904829,
            "model_is_public": True,
            "model_template": "gemma",
            "min_instances": 1,
            "model_meta": "{\"template\":\"gemma\",\"category\":\"chat\",\"can_finetune\":true,\"can_inference\":true}",
            "max_instances": 1
        }
        base_model.deploy(base_model_config)

        # Test a chat round-trip against the base model
        response = base_model.chat([{"role": "user", "content": "你是谁"}])
        print("Base model response:", response, flush=True)

        # Start the finetune job
        finetune = FinetuneJob(client)
        finetune_config = {
            "model": model.get('name'),
            "training_file": "dataset-storage-f90b82cc7ab88911",
            "gpu_model": "NVIDIA-GeForce-RTX-3090",
            "gpu_count": 1,
            "finetune_type": "lora",
            "hyperparameters": {
                "n_epochs": "3.0",
                "batch_size": "4",
                "learning_rate_multiplier": "5e-5"
            },
            "model_saved_type": "lora",
            "model_id": model.get('id'),
            "model_name": model.get('name'),
            "model_size": model.get('size'),
            "model_is_public": model.get('is_public'),
            "model_template": model.get('template'),
            "model_category": "chat",
            "dataset_id": "dataset-storage-f90b82cc7ab88911",
            "model_meta": model.get('meta'),
            "dataset_name": "llama-factory/alpaca_data_zh_short",
            "dataset_size": 14119,
            "dataset_is_public": True
        }
        finetune.start(finetune_config)

        # Deploy the finetuned model
        finetuned_model = InferenceService(client)
        finetuned_model_config = {**base_model_config,
                                  "adapter_id": finetune.adapter_id,
                                  "adapter_is_public": False}
        finetuned_model.deploy(finetuned_model_config)

        # Test a chat round-trip against the finetuned model
        response = finetuned_model.chat([{"role": "user", "content": "你是谁"}])
        print("Finetuned model response:", response, flush=True)
    except Exception as e:
        print(f"An error occurred: {e}", flush=True)
        # Sweep any resources created so far, then re-raise so CI sees the failure
        cleanup_handler(None, None)
        raise

if __name__ == "__main__":
    main()
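Taken together, the script's cleanup strategy is a single pattern: register every resource the moment it is created, discard it on successful deletion, and sweep whatever is left when a signal or error arrives. A condensed, standalone sketch of that lifecycle (illustrative names, not the repo's API):

import signal
import sys

registry = set()

def create(name: str) -> None:
    registry.add(name)  # register first, so an interruption mid-run still sweeps it

def delete(name: str) -> None:
    registry.discard(name)  # no-op if already gone, so double-deletion is safe

def sweep(signum=None, frame=None):
    for name in list(registry):  # iterate over a copy; delete() mutates the set
        delete(name)
    if signum is not None:
        sys.exit(0)

signal.signal(signal.SIGTERM, sweep)

create("svc-a")
delete("svc-a")  # normal path: nothing left for the sweep
create("svc-b")  # interrupted path: a SIGTERM would still delete svc-b
sweep()          # manual sweep, as main() does in its except block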
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -21,6 +21,6 @@ jobs:
          pip install requests  # install the required dependency
      - name: Run task
        run: timeout --preserve-status --signal=TERM 30m python .github/scripts/inference.py
        env:
          SXWL_TOKEN: ${{ secrets.AUTHORIZATION }}
2 changes: 2 additions & 0 deletions .gitignore
@@ -59,3 +59,5 @@ coverage.xml
# Debug binaries
__debug_bin*
.secrets

cmd/main.go
2 changes: 2 additions & 0 deletions cpodoperator/api/v1beta1/inference_type.go
@@ -48,6 +48,8 @@ type InferenceSpec struct {
	GPUCount int `json:"gpuCount,omitempty"`

	AutoscalerOptions *AutoscalerOptions `json:"autoscalerOptions,omitempty"`

	Params *string `json:"params,omitempty"`
}

type AutoscalerOptions struct {
	// ... (fields collapsed in the diff view)
}
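The new `Params` field is the point of this commit: an optional free-form string that rides along with the inference spec. A hypothetical sketch of a request payload using it (the key name follows the `json:"params,omitempty"` tag; the inner keys and the exact HTTP contract are invented for illustration):

import json

# Hypothetical request body for creating an inference service; only "params"
# is new in this commit, and its contents here are made up.
inference_request = {
    "gpu_model": "NVIDIA-GeForce-RTX-3090",
    "gpu_count": 1,
    "model_id": "model-storage-0ce92f029254ff34",
    "params": json.dumps({"max_model_len": 4096, "dtype": "bfloat16"}),
}
print(json.dumps(inference_request, indent=2))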
5 changes: 5 additions & 0 deletions cpodoperator/api/v1beta1/zz_generated.deepcopy.go

Some generated files are not rendered by default.

(The remaining changed files are not shown here.)
