一个用于训练中文大语言模型的demo代码, 支持预训练、SFT(Supervised Fine-tuning)和 DPO(Direct Preference Optimization)训练
# 安装基础依赖
pip install torch transformers datasets accelerate wandb einops
# 克隆仓库
git clone https://github.com/your-username/chinese-llm-training.git
cd chinese-llm-training
from train_data import prepare_data
# 准备与训练数据
tokenized_datasets, tokenizer = prepare_data(
tokenizer_path="path/to/tokenizer" # 可选
)
[
{
"instruction": "请介绍一下你自己",
"response": "我是一个AI助手,我可以帮助回答问题、解决问题,和进行友好的对话。"
}
]
[
{
"instruction": "请介绍一下你自己",
"chosen": "我是一个AI助手,我会尊重用户并提供有帮助的回答。",
"rejected": "我是最强大的AI,我什么都知道。"
}
]
python train_llm.py
python sft_train.py
python dpo_train.py
所有训练脚本都支持通过 wandb_config 配置训练参数:
wandb_config = {
"d_model": 512, # 模型维度
"n_heads": 8, # 注意力头数
"n_layers": 6, # 层数
"d_ff": 2048, # 前馈网络维度
"batch_size": 32, # 批次大小
"learning_rate": 5e-4, # 学习率
"num_epochs": 3, # 训练轮数
"use_moe": False, # 是否使用MoE
# ... 其他参数
}
chinese-llm-training/
├── models.py # 模型定义
├── tokenizer.py # 分词器实现
├── train_data.py # 预训练数据处理
├── train_llm.py # 预训练脚本
├── sft_data.py # SFT数据处理
├── sft_train.py # SFT训练脚本
├── dpo_data.py # DPO数据处理
├── dpo_train.py # DPO训练脚本
└── README.md
- 基于 Transformer 的自定义 LLM 实现
- 支持 RoPE(Rotary Position Embedding)位置编码
- 可选的 MoE(Mixture of Experts)前馈网络
- 支持 Flash Attention 加速
- 支持梯度检查点(Gradient Checkpointing)
- 预训练(Pre-training)
- 有监督微调(SFT)
- 直接偏好优化(DPO)
- Flash Attention 加速
- 混合精度训练(FP16)
- 梯度累积
- 梯度检查点
- torch.compile 加速
- 多进程数据加载优化
- wandb 集成
- 训练损失追踪
- GPU 内存监控
- 实时生成示例
- 训练进度记录
- 支持 RoPE 位置编码
- 可选的 MoE 层
- Flash Attention 支持
- 灵活的配置选项
- 混合精度训练
- 梯度累积
- 梯度检查点
- Flash Attention
- torch.compile 加速
-
显存优化:
- 使用梯度检查点
- 启用混合精度训练
- 适当的批次大小
- 合理的梯度累积步数
-
训练速度:
- 启用 Flash Attention
- 使用 torch.compile
- 优化数据加载
- 调整 worker 数量
-
模型选择:
- 根据实际需求选择模型大小
- 考虑是否启用 MoE
- 权衡性能和效果
- 确保 CUDA 版本与 PyTorch 匹配
- 预训练前检查数据质量
- SFT 和 DPO 训练需要高质量的数据
- 根据 GPU 显存调整配置参数
- 定期保存检查点
欢迎提交 Issue 和 Pull Request!
MIT License