本文档旨在比较 RWKV 模型的六个不同版本(v1 至 v6),并详细介绍每个版本的特性、改进和性能。以下是对这六个模型版本的详细分析和比较。
RWKV v1
- 初始版本,基础实现 RWKV 时间混合和通道混合模块。
- 主要特性:
- 使用时间混合(Time-mix)和通道混合(Channel-mix)模块。
- 采用标准的线性层和嵌入层初始化。
- 使用掩码来处理因果关系。
RWKV v2
- 增强版本,改进了时间混合和通道混合的实现。
- 主要改进:
- 优化了模型加载和状态管理。
- 增加了新的归一化方法。
- 提升了训练和推理效率。
RWKV v3
- 进一步优化的版本,主要集中在性能提升。
- 主要改进:
- 调整了层数和嵌入维度,提供更灵活的配置选项。
- 增加了预处理步骤,提高了推理效率。
RWKV v4
- 增加了对更大规模模型的支持,提升了模型复杂度。
- 主要改进:
- 支持24层和1024维嵌入。
- 增加了更多的参数调优选项。
RWKV v5
- 继续提升模型规模和复杂度,并优化了模型架构。
- 主要改进:
- 支持更高的嵌入维度(2048)。
- 引入了新的时间混合和通道混合方法,提升了模型性能。
RWKV v6
- 最新版本,综合了前几个版本的改进,并引入了一些新特性。
- 主要改进:
- 增加了对更大词汇表(65536)的支持。
- 采用了改进的混合方法,提升了推理速度和准确性。
1. 架构与实现
- 时间混合(Time-Mix)和通道混合(Channel-Mix):
-
v1:基本实现,功能完备。
class RWKV_TimeMix(nn.Module): def __init__(self, config, layer_id): super().__init__() assert config.n_attn % config.n_head == 0 self.layer_id = layer_id self.ctx_len = config.ctx_len self.n_head = config.n_head self.head_size = config.n_attn // config.n_head with torch.no_grad(): # initial time_w curves for better convergence ww = torch.ones(config.n_head, config.ctx_len) curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance for h in range(config.n_head): if h < config.n_head - 1: decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1)) else: decay_speed = 0.0 ww[h] = torch.exp(curve * decay_speed) self.time_w = nn.Parameter(ww) self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len)) self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1)) self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1)) self.time_shift = nn.ZeroPad2d((0,0,1,-1)) self.key = nn.Linear(config.n_embd, config.n_attn) self.value = nn.Linear(config.n_embd, config.n_attn) self.receptance = nn.Linear(config.n_embd, config.n_attn) self.output = nn.Linear(config.n_attn, config.n_embd)
-
v2:优化了时间混合和通道混合,提升了计算效率。
class RWKV_ChannelMix(nn.Module): def __init__(self, layer_id): super().__init__() self.layer_id = layer_id self.time_shift = nn.ZeroPad2d((0,0,1,-1)) self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd)) hidden_sz = 4 * n_embd self.key = nn.Linear(n_embd, hidden_sz, bias=False) self.receptance = nn.Linear(n_embd, n_embd, bias=False) self.value = nn.Linear(hidden_sz, n_embd, bias=False) def forward(self, x): x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix) k = self.key(x) k = torch.square(torch.relu(k)) kv = self.value(k) rkv = torch.sigmoid(self.receptance(x)) * kv return rkv
-
v3:进一步优化,并增加了灵活的配置选项。
class RWKV_ChannelMix(nn.Module): def __init__(self, layer_id): super().__init__() self.layer_id = layer_id self.time_shift = nn.ZeroPad2d((0,0,1,-1)) self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd)) self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd)) hidden_sz = 4 * n_embd self.key = nn.Linear(n_embd, hidden_sz, bias=False) self.receptance = nn.Linear(n_embd, n_embd, bias=False) self.value = nn.Linear(hidden_sz, n_embd, bias=False) def forward(self, x): xx = self.time_shift(x) xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) k = self.key(xk) k = torch.square(torch.relu(k)) kv = self.value(k) rkv = torch.sigmoid(self.receptance(xr)) * kv return rkv
-
v4:支持更大规模模型,提升了时间混合和通道混合的处理能力。
class RWKV_RNN(torch.jit.ScriptModule): def __init__(self, args): super().__init__() self.args = args self.eval() # set torch to inference mode w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu') for k in w.keys(): if '.time_' in k: w[k] = w[k].squeeze() if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x} else: w[k] = w[k].float() # convert to f32 type self.w = types.SimpleNamespace() # set self.w from w self.w.blocks = {} for k in w.keys(): parts = k.split('.') last = parts.pop() here = self.w for p in parts: if p.isdigit(): p = int(p) if p not in here: here[p] = types.SimpleNamespace() here = here[p] else: if not hasattr(here, p): setattr(here, p, types.SimpleNamespace()) here = getattr(here, p) setattr(here, last, w[k])
-
v5:引入了新的混合方法,进一步提升了性能。
class RWKV_RNN(MyModule): def __init__(self, args): super().__init__() self.args = args self.eval() # set torch to inference mode w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu') for k in w.keys(): w[k] = w[k].float() # convert to f32 type if '.time_' in k: w[k] = w[k].squeeze() if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1) if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1) self.n_head = w['blocks.0.att.time_decay'].shape[0] self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head self.w = types.SimpleNamespace() # set self.w from w self.w.blocks = {} for k in w.keys(): parts = k.split('.') last = parts.pop() here = self.w for p in parts: if p.isdigit(): p = int(p) if p not in here: here[p] = types.SimpleNamespace() here = here[p] else: if not hasattr(here, p): setattr(here, p, types.SimpleNamespace()) here = getattr(here, p) setattr(here, last, w[k])
-
v6:改进了混合方法,提升了整体性能和效率。
class RWKV_RNN(MyModule): def __init__(self, args): super().__init__() self.args = args self.eval() # set torch to inference mode w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu') for k in w.keys(): w[k] = w[k].float() # convert to f32 type if '.time_' in k: w[k] = w[k].squeeze() if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1) self.n_head = w['blocks.0.att.time_faaaa'].shape[0] self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head self.w = types.SimpleNamespace() # set self.w from w self.w.blocks = {} for k in w.keys(): parts = k.split('.') last = parts.pop() here = self.w for p in parts: if p.isdigit(): p = int(p) if p not in here: here[p] = types.SimpleNamespace() here = here[p] else: if not hasattr(here, p): setattr(here, p, types.SimpleNamespace()) here = getattr(here, p) setattr(here, last, w[k])
-
2. 模型规模
- 层数和嵌入维度:
- v1:标准配置,适用于基础任务。
- v2:支持12层和768维嵌入。
- v3:提供12层和24层选项,嵌入维度为768和1024。
- v4:支持24层和1024维嵌入。
- v5:嵌入维度增加至2048。
- v6:进一步增加模型复杂度,支持更大词汇表。
3. 性能与效率
- 推理速度和资源消耗:
- v1:基础实现,资源消耗适中。
- v2:优化后,推理速度提升。
- v3:预处理步骤的增加,提高了推理效率。
- v4:更大规模模型下的性能优化。
- v5:新的混合方法提升了推理速度和准确性。
- v6:综合改进,推理速度和资源利用进一步优化。
4. 词汇表和上下文长度
- 词汇表大小和上下文长度支持:
- v1-v4:词汇表大小和上下文长度逐步增加。
- v5:支持更大上下文长度,适应复杂任务。
- v6:支持最大65536的词汇表和更长的上下文长度。
RWKV 模型在每个版本中不断优化和提升,从基础的 v1 到复杂且高效的 v6,模型的性能和功能都有了显著的进步。以下是每个版本的推荐使用场景:
- v1:适用于基础任务和初步研究。
- v2:适用于需要更高效率和优化的任务。
- v3:适用于需要灵活配置和更高性能的应用。
- v4:适用于大规模模型的训练和推理任务。
- v5:适用于需要高精度和高效推理的复杂任务。
- v6:适用于最前沿的研究和应用,提供最高的性能和效率。
每个版本在其特定的改进点上都为用户提供了更好的选择,根据具体需求选择合适的版本将能充分发挥 RWKV 模型的优势。