Skip to content

Latest commit

 

History

History
 
 

Folders and files

NameName
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 
 
 
 
 
 
 

RWKV 模型版本比较报告

本文档旨在比较 RWKV 模型的六个不同版本(v1 至 v6),并详细介绍每个版本的特性、改进和性能。以下是对这六个模型版本的详细分析和比较。


版本概述

RWKV v1

  • 初始版本,基础实现 RWKV 时间混合和通道混合模块。
  • 主要特性:
    • 使用时间混合(Time-mix)和通道混合(Channel-mix)模块。
    • 采用标准的线性层和嵌入层初始化。
    • 使用掩码来处理因果关系。

RWKV v2

  • 增强版本,改进了时间混合和通道混合的实现。
  • 主要改进:
    • 优化了模型加载和状态管理。
    • 增加了新的归一化方法。
    • 提升了训练和推理效率。

RWKV v3

  • 进一步优化的版本,主要集中在性能提升。
  • 主要改进:
    • 调整了层数和嵌入维度,提供更灵活的配置选项。
    • 增加了预处理步骤,提高了推理效率。

RWKV v4

  • 增加了对更大规模模型的支持,提升了模型复杂度。
  • 主要改进:
    • 支持24层和1024维嵌入。
    • 增加了更多的参数调优选项。

RWKV v5

  • 继续提升模型规模和复杂度,并优化了模型架构。
  • 主要改进:
    • 支持更高的嵌入维度(2048)。
    • 引入了新的时间混合和通道混合方法,提升了模型性能。

RWKV v6

  • 最新版本,综合了前几个版本的改进,并引入了一些新特性。
  • 主要改进:
    • 增加了对更大词汇表(65536)的支持。
    • 采用了改进的混合方法,提升了推理速度和准确性。

详细比较

1. 架构与实现

  • 时间混合(Time-Mix)和通道混合(Channel-Mix)
    • v1:基本实现,功能完备。

      class RWKV_TimeMix(nn.Module):
          def __init__(self, config, layer_id):
              super().__init__()
              assert config.n_attn % config.n_head == 0
              self.layer_id = layer_id
              self.ctx_len = config.ctx_len
              self.n_head = config.n_head
              self.head_size = config.n_attn // config.n_head
      
              with torch.no_grad(): # initial time_w curves for better convergence
                  ww = torch.ones(config.n_head, config.ctx_len)
                  curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
                  for h in range(config.n_head):
                      if h < config.n_head - 1:
                          decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
                      else:
                          decay_speed = 0.0
                      ww[h] = torch.exp(curve * decay_speed)
              self.time_w = nn.Parameter(ww)
              self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
              self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
              self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
                      
              self.time_shift = nn.ZeroPad2d((0,0,1,-1))
      
              self.key = nn.Linear(config.n_embd, config.n_attn)
              self.value = nn.Linear(config.n_embd, config.n_attn)
              self.receptance = nn.Linear(config.n_embd, config.n_attn)
      
              self.output = nn.Linear(config.n_attn, config.n_embd)
    • v2:优化了时间混合和通道混合,提升了计算效率。

      class RWKV_ChannelMix(nn.Module):
          def __init__(self, layer_id):
              super().__init__()
              self.layer_id = layer_id
      
              self.time_shift = nn.ZeroPad2d((0,0,1,-1))
              self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd))
      
              hidden_sz = 4 * n_embd
              self.key = nn.Linear(n_embd, hidden_sz, bias=False)
              self.receptance = nn.Linear(n_embd, n_embd, bias=False)
              self.value = nn.Linear(hidden_sz, n_embd, bias=False)
      
          def forward(self, x):
              x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
              k = self.key(x)
              k = torch.square(torch.relu(k))
              kv = self.value(k)
              rkv = torch.sigmoid(self.receptance(x)) * kv
              return rkv
    • v3:进一步优化,并增加了灵活的配置选项。

      class RWKV_ChannelMix(nn.Module):
          def __init__(self, layer_id):
              super().__init__()
              self.layer_id = layer_id
      
              self.time_shift = nn.ZeroPad2d((0,0,1,-1))
              self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
              self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
              
              hidden_sz = 4 * n_embd
              self.key = nn.Linear(n_embd, hidden_sz, bias=False)
              self.receptance = nn.Linear(n_embd, n_embd, bias=False)
              self.value = nn.Linear(hidden_sz, n_embd, bias=False)
      
          def forward(self, x):
              xx = self.time_shift(x)
              xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
              xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
              k = self.key(xk)
              k = torch.square(torch.relu(k))
              kv = self.value(k)
              rkv = torch.sigmoid(self.receptance(xr)) * kv
              return rkv
    • v4:支持更大规模模型,提升了时间混合和通道混合的处理能力。

      class RWKV_RNN(torch.jit.ScriptModule):
          def __init__(self, args):
              super().__init__()
              self.args = args
              self.eval() # set torch to inference mode
              
              w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
              for k in w.keys():
                  if      '.time_' in k: w[k] = w[k].squeeze()
                  if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
                  else: w[k] = w[k].float() # convert to f32 type
              self.w = types.SimpleNamespace() # set self.w from w
              self.w.blocks = {}
              for k in w.keys():
                  parts = k.split('.')
                  last = parts.pop()
                  here = self.w
                  for p in parts:
                      if p.isdigit():
                          p = int(p)
                          if p not in here: here[p] = types.SimpleNamespace()
                          here = here[p]
                      else:
                          if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                          here = getattr(here, p)
                  setattr(here, last, w[k])
    • v5:引入了新的混合方法,进一步提升了性能。

      class RWKV_RNN(MyModule):
          def __init__(self, args):
              super().__init__()
              self.args = args
              self.eval() # set torch to inference mode
              
              w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
              for k in w.keys():
                  w[k] = w[k].float() # convert to f32 type
                  if      '.time_' in k: w[k] = w[k].squeeze()
                  if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
                  if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
      
              self.n_head = w['blocks.0.att.time_decay'].shape[0]
              self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
              
              self.w = types.SimpleNamespace() # set self.w from w
              self.w.blocks = {}
              for k in w.keys():
                  parts = k.split('.')
                  last = parts.pop()
                  here = self.w
                  for p in parts:
                      if p.isdigit():
                          p = int(p)
                          if p not in here: here[p] = types.SimpleNamespace()
                          here = here[p]
                      else:
                          if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                          here = getattr(here, p)
                  setattr(here, last, w[k])
    • v6:改进了混合方法,提升了整体性能和效率。

      class RWKV_RNN(MyModule):
          def __init__(self, args):
              super().__init__()
              self.args = args
              self.eval() # set torch to inference mode
              
              w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
              for k in w.keys():
                  w[k] = w[k].float() # convert to f32 type
                  if      '.time_' in k: w[k] = w[k].squeeze()
                  if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
              self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
              self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
              self.w = types.SimpleNamespace() # set self.w from w
              self.w.blocks = {}
              for k in w.keys():
                  parts = k.split('.')
                  last = parts.pop()
                  here = self.w
                  for p in parts:
                      if p.isdigit():
                          p = int(p)
                          if p not in here: here[p] = types.SimpleNamespace()
                          here = here[p]
                      else:
                          if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                          here = getattr(here, p)
                  setattr(here, last, w[k])

2. 模型规模

  • 层数和嵌入维度
    • v1:标准配置,适用于基础任务。
    • v2:支持12层和768维嵌入。
    • v3:提供12层和24层选项,嵌入维度为768和1024。
    • v4:支持24层和1024维嵌入。
    • v5:嵌入维度增加至2048。
    • v6:进一步增加模型复杂度,支持更大词汇表。

3. 性能与效率

  • 推理速度和资源消耗
    • v1:基础实现,资源消耗适中。
    • v2:优化后,推理速度提升。
    • v3:预处理步骤的增加,提高了推理效率。
    • v4:更大规模模型下的性能优化。
    • v5:新的混合方法提升了推理速度和准确性。
    • v6:综合改进,推理速度和资源利用进一步优化。

4. 词汇表和上下文长度

  • 词汇表大小和上下文长度支持
    • v1-v4:词汇表大小和上下文长度逐步增加。
    • v5:支持更大上下文长度,适应复杂任务。
    • v6:支持最大65536的词汇表和更长的上下文长度。

总结

RWKV 模型在每个版本中不断优化和提升,从基础的 v1 到复杂且高效的 v6,模型的性能和功能都有了显著的进步。以下是每个版本的推荐使用场景:

  • v1:适用于基础任务和初步研究。
  • v2:适用于需要更高效率和优化的任务。
  • v3:适用于需要灵活配置和更高性能的应用。
  • v4:适用于大规模模型的训练和推理任务。
  • v5:适用于需要高精度和高效推理的复杂任务。
  • v6:适用于最前沿的研究和应用,提供最高的性能和效率。

每个版本在其特定的改进点上都为用户提供了更好的选择,根据具体需求选择合适的版本将能充分发挥 RWKV 模型的优势。