🚄 BMTrain



1. Installation

From PyPI (recommended)

$ pip install bmtrain

From source

$ git clone https://github.com/OpenBMB/BMTrain.git
$ cd BMTrain
$ python setup.py install

2. Usage

Step 1: Enable BMTrain

To use BMTrain, import the bmtrain package and call bmtrain.init_distributed at the very beginning of your code:

import bmtrain as bmt
bmt.init_distributed(
    seed=0,
    # ...
)

Note: When using BMTrain, do not use PyTorch's built-in distributed module, including torch.distributed.init_process_group and the related communication functions.
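
Instead, rank and world-size information can be read through BMTrain's own helpers. A minimal sketch, assuming the helpers bmt.rank(), bmt.world_size(), bmt.print_rank, and bmt.synchronize behave as in the BMTrain examples:

import bmtrain as bmt

bmt.init_distributed(seed=0)

# Query the process group via BMTrain instead of torch.distributed
# (assumed helpers: bmt.rank(), bmt.world_size()).
print("rank %d of %d" % (bmt.rank(), bmt.world_size()))

# print_rank prints only on rank 0, keeping multi-GPU logs readable.
bmt.print_rank("initialization finished")

# Barrier-style synchronization across all ranks (assumed helper).
bmt.synchronize()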

Step 2: Use ZeRO-3 Optimization

To use ZeRO-3 optimization, make the following simple substitutions in your model code:

  • torch.nn.Module -> bmtrain.DistributedModule
  • torch.nn.Parameter -> bmtrain.DistributedParameter

and apply checkpointing to the appropriate modules (the transformer blocks in the example below).

Original code:

import torch
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.empty(1024))
        self.module_list = torch.nn.ModuleList([
            SomeTransformerBlock(),
            SomeTransformerBlock(),
            SomeTransformerBlock()
        ])
    
    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x

Modified code:

import torch
import bmtrain as bmt
class MyModule(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024))
        self.module_list = torch.nn.ModuleList([
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock())
        ])
    
    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x
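
Once the model is declared this way, its distributed parameters still need to be initialized on every rank, and checkpoints should be read and written through BMTrain so that the sharded parameters are gathered and scattered correctly. A minimal sketch, assuming the helpers bmt.init_parameters, bmt.save, and bmt.load behave as in the BMTrain examples folder:

import bmtrain as bmt

bmt.init_distributed(seed=0)

model = MyModule()

# Initialize the sharded (ZeRO-3) parameters on every rank
# (assumed helper: bmt.init_parameters).
bmt.init_parameters(model)

# Save/load through BMTrain so the parameter shards are handled correctly
# (assumed helpers: bmt.save / bmt.load).
bmt.save(model, "checkpoint.pt")
bmt.load(model, "checkpoint.pt")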
    

Step 3: Communication Optimization

To further reduce communication overhead by overlapping communication with computation, you can use TransformerBlockList for additional optimization. This requires the following simple substitutions in your code:

  • torch.nn.ModuleList -> bmtrain.TransformerBlockList
  • for module in self.module_list: x = module(x, ...) -> x = self.module_list(x, ...)

Original code:

import torch
import bmtrain as bmt
class MyModule(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024))
        self.module_list = torch.nn.ModuleList([
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock())
        ])
    
    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x
    

Modified code:

import torch
import bmtrain as bmt
class MyModule(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024))
        self.module_list = bmt.TransformerBlockList([
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock())
        ])
    
    def forward(self):
        x = self.param
        x = self.module_list(x, 1, 2, 3)
        return x
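
The snippets above use SomeTransformerBlock as a placeholder. Purely for illustration (this class is not part of BMTrain), it can be any torch.nn.Module whose forward accepts the extra positional arguments forwarded to every block, for example:

import torch

class SomeTransformerBlock(torch.nn.Module):
    # Hypothetical stand-in for a real transformer layer; any nn.Module
    # with a compatible forward signature can be wrapped in CheckpointBlock.
    def __init__(self, dim=1024):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x, a, b, c):
        # The extra positional arguments (1, 2, 3 above) are forwarded
        # to every block, by the Python loop or by TransformerBlockList.
        return torch.relu(self.linear(x)) + a + b + c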
    

Step 4: Launch Distributed Training

BMTrain supports PyTorch's native distributed training launchers:

torch.distributed.launch

$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py

torchrun

$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py
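For example, on a single machine (a hypothetical setup with 4 GPUs), torchrun's standalone mode avoids setting the rendezvous variables by hand:

$ torchrun --standalone --nnodes=1 --nproc_per_node=4 train.py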

For more information, please refer to the official PyTorch documentation: Launch utility.

3. Other Notes

The BMTrain toolkit makes low-level modifications to PyTorch. If your program produces unexpected results, please report the relevant information in an issue.

For more examples, please refer to the examples folder.
