Skip to content

Commit

Permalink
- add weight decay(untested)
Browse files Browse the repository at this point in the history
  • Loading branch information
Huanghe committed May 30, 2023
1 parent 13e7dc2 commit 24eee36
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 39 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,7 @@ RWKV-v4neo/wandb/
RWKV-v4neo-lora/out*/
RWKV-v4neo-lora/wandb/
data/
.vscode/
.vscode/

# Pycharm
.idea/
104 changes: 66 additions & 38 deletions RWKV-v4neo-lora/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy

if importlib.util.find_spec('deepspeed'):
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
Expand All @@ -25,15 +26,15 @@
"alpha": 0,
"dropout": 0,
"parts": {"att", "ln", "time"},
"layers":None,
"layers": None,
}


try:
print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
except:
os.environ["RWKV_MY_TESTING"] = ''


def __nop(ob):
return ob

Expand All @@ -44,7 +45,6 @@ def __nop(ob):
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method


########################################################################################################
# CUDA Kernel
########################################################################################################
Expand All @@ -57,14 +57,20 @@ def __nop(ob):
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
if os.environ.get("WN_WU_GRAD_CLIP"):
print("Use WN_WU_GRAD_CLIP")
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16_gclip.cpp", "cuda/wkv_cuda_bf16_gclip.cu"],
verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "-O0",
#"--use_fast_math",
# "-O3", "-Xptxas -O3",
# "--extra-device-vectorization",
f"-DTmax={T_MAX}", f"-DMAX_gw=\\({1e5}\\) -DMIN_gw=\\({-1e5}\\) -DMAX_gu=\\({1e5}\\) -DMIN_gu=\\({-1e5}\\)"])
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16_gclip.cpp", "cuda/wkv_cuda_bf16_gclip.cu"],
verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "-O0",
# "--use_fast_math",
# "-O3", "-Xptxas -O3",
# "--extra-device-vectorization",
f"-DTmax={T_MAX}",
f"-DMAX_gw=\\({1e5}\\) -DMIN_gw=\\({-1e5}\\) -DMAX_gu=\\({1e5}\\) -DMIN_gu=\\({-1e5}\\)"])
else:
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"],
verbose=True,
extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math",
"-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])


class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
Expand All @@ -81,6 +87,7 @@ def forward(ctx, B, T, C, w, u, k, v):
wkv_cuda.forward(B, T, C, w, u, k, v, y)
ctx.save_for_backward(w, u, k, v, y)
return y

@staticmethod
def backward(ctx, gy):
B = ctx.B
Expand All @@ -97,11 +104,13 @@ def backward(ctx, gy):
else:
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format,
dtype=torch.bfloat16)
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format,
dtype=torch.bfloat16)
if os.environ.get("DEBUG_WKV"):
torch.cuda.set_sync_debug_mode(1)
for name,check in [("gw",gw),("gu",gu),("gk",gk),("gv",gv),("gy",gy)]:
for name, check in [("gw", gw), ("gu", gu), ("gk", gk), ("gv", gv), ("gy", gy)]:
if check.isnan().any():
print(f"Find NaN in {name} before backward")
print(check)
Expand All @@ -110,7 +119,7 @@ def backward(ctx, gy):
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
if os.environ.get("DEBUG_WKV"):
torch.cuda.synchronize()
for name,check in [("gw",gw),("gu",gu),("gk",gk),("gv",gv),("gy",gy)]:
for name, check in [("gw", gw), ("gu", gu), ("gk", gk), ("gv", gv), ("gy", gy)]:
if check.isnan().any():
print(f"Find NaN in {name}")
print(check)
Expand All @@ -121,7 +130,11 @@ def backward(ctx, gy):
gu = torch.sum(gu, dim=0)
return (None, None, None, gw, gu, gk, gv)
else:
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True,
extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3",
"--extra-device-vectorization", f"-DTmax={T_MAX}"])


class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
Expand Down Expand Up @@ -149,6 +162,7 @@ def forward(ctx, B, T, C, w, u, k, v):
return y.half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
return y.bfloat16()

@staticmethod
def backward(ctx, gy):
B = ctx.B
Expand Down Expand Up @@ -205,8 +219,8 @@ def __init__(self, in_features: int, out_features: int, bias: bool):

def forward(self, x):
return (
F.linear(x, self.weight) + self.scaling *
F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B))
F.linear(x, self.weight) + self.scaling *
F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B))


@functools.wraps(LoraLinear)
Expand Down Expand Up @@ -288,7 +302,7 @@ def __init__(self, args, layer_id):
if 'a' not in os.environ["RWKV_MY_TESTING"]:
@MyFunction
def jit_func(self, x):
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
Expand All @@ -309,13 +323,13 @@ def forward(self, x):
def QKV(self, q, k, v):
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.att_mask == 0, float('-inf'))
att = F.softmax(att, dim = -1)
att = F.softmax(att, dim=-1)
x = att @ v
return x

@MyFunction
def jit_funcQKV(self, x):
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
Expand All @@ -338,6 +352,7 @@ def forward(self, x):
rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
return rwkv


########################################################################################################

class RWKV_ChannelMix(MyModule):
Expand Down Expand Up @@ -373,6 +388,7 @@ def forward(self, x):
kv = self.value(k)
return torch.sigmoid(self.receptance(xr)) * kv


class MishGLU(MyModule):
def __init__(self, args, layer_id):
super().__init__()
Expand Down Expand Up @@ -402,6 +418,7 @@ def forward(self, x):
b = self.bb(xb)
return self.value(a * F.mish(b))


########################################################################################################
# The RWKV Model with our blocks
########################################################################################################
Expand All @@ -419,8 +436,8 @@ def __init__(self, args, layer_id):
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(args.n_embd)
if args.my_pos_emb > 0:
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
self.pos_emb_x = nn.Parameter(torch.zeros((1, args.my_pos_emb, args.n_embd)))
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb, 1, args.n_embd)))

if self.layer_id == 0 and self.args.pre_ffn > 0:
self.ffnPre = RWKV_ChannelMix(args, 0)
Expand All @@ -445,7 +462,7 @@ def forward(self, x, x_emb=None):
if self.layer_id == 0:
x = self.ln0(x)
if args.my_pos_emb > 0:
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :]
x = x + pos_emb

if self.layer_id == 0 and args.pre_ffn > 0:
Expand All @@ -471,23 +488,23 @@ def forward(ctx, loss, y):
return loss

@staticmethod
def backward(ctx, grad_output): #这个函数会不会影响batch和grad_accu的一致性?感觉上会。梯度累积时,factor变大了。但是只有loss缩放,这里的正则化项反而没有缩放
def backward(ctx, grad_output): # 这个函数会不会影响batch和grad_accu的一致性?感觉上会。梯度累积时,factor变大了。但是只有loss缩放,这里的正则化项反而没有缩放
y = ctx.saved_tensors[0]
# to encourage the logits to be close to 0
factor = 1e-4 / (y.shape[0] * y.shape[1]) #这一行类似crossentropy在token上平均。总感觉mask loss并不能真正保证对应位置不产生loss
factor = 1e-4 / (y.shape[0] * y.shape[1]) # 这一行类似crossentropy在token上平均。总感觉mask loss并不能真正保证对应位置不产生loss
maxx, ids = torch.max(y, -1, keepdim=True)
gy = torch.zeros_like(y)
if os.environ.get("WN_FIX_L2WRAP"): #实现batch等价性,并且防止对已经较小的值下拉
maxx[maxx<3.]=0. #并不永远向下拉logits,只对大于阈值的往下拉
if os.environ.get("WN_FIX_L2WRAP"): # 实现batch等价性,并且防止对已经较小的值下拉
maxx[maxx < 3.] = 0. # 并不永远向下拉logits,只对大于阈值的往下拉
gy.scatter_(-1, ids, maxx * factor * grad_output)
else:
gy.scatter_(-1, ids, maxx * factor)
return (grad_output, gy)
#修改一下
#正向 l2loss = L2(loss,y)
#如果有梯度累积,那么dfinal/dl2loss=1/accumulate
#所以应当在scatter_时加入grad_output
#话说,这个名字为什么叫l2,明明也并不是2范数吧。
# 修改一下
# 正向 l2loss = L2(loss,y)
# 如果有梯度累积,那么dfinal/dl2loss=1/accumulate
# 所以应当在scatter_时加入grad_output
# 话说,这个名字为什么叫l2,明明也并不是2范数吧。


class RWKV(pl.LightningModule):
Expand Down Expand Up @@ -518,6 +535,7 @@ def __init__(self, args):
def configure_optimizers(self):
args = self.args
if args.layerwise_lr > 0:
lr_decay = set()
lr_1x = set()
lr_2x = set()
lr_3x = set()
Expand All @@ -534,23 +552,30 @@ def configure_optimizers(self):
lr_2x.add(n)
elif "time_first" in n:
lr_3x.add(n)
elif len(p.squeeze().shape) >= 2:
lr_decay.add(n)
else:
lr_1x.add(n)
lr_decay = sorted(list(lr_decay))
lr_1x = sorted(list(lr_1x))
lr_2x = sorted(list(lr_2x))
lr_3x = sorted(list(lr_3x))
# print('decay', lr_decay)
# print('1x', lr_1x)
# print('2x', lr_2x)
# print('3x', lr_3x)
param_dict = {n: p for n, p in self.named_parameters()}
if args.my_pile_stage == 2:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},
# test: 2e-3 / args.lr_init},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},
# test: 3e-3 / args.lr_init},
]
else:
optim_groups = [
{"params": [param_dict[n] for n in lr_decay], "weight_decay": args.weight_decay, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
Expand All @@ -565,12 +590,14 @@ def configure_optimizers(self):
optim_groups = [g for g in optim_groups if len(g["params"]) > 0]

if self.deepspeed_offload:
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps,
bias_correction=True, adamw_mode=True, amsgrad=False)
# return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
if self.args.strategy == 'single_device' and self.args.precision == 'bf16':
return torch.optim.Adam(optim_groups,lr=self.args.lr_init,betas=self.args.betas, eps=self.args.adam_eps)
return torch.optim.Adam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps)
else:
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps,
bias_correction=True, adam_w_mode=True, amsgrad=False)
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)

@property
Expand Down Expand Up @@ -695,7 +722,8 @@ def generate_init_weight(self):
else:
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.",
".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
if kk in n:
scale = 0
if n == "head.weight":
Expand Down
1 change: 1 addition & 0 deletions RWKV-v4neo-lora/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
parser.add_argument("--accumulate_grad_batches_dict", default=None, type=str)

parser.add_argument("--debug",action="store_true",default=False)
parser.add_argument("--weight_decay", default=0, type=float)

parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
Expand Down

0 comments on commit 24eee36

Please sign in to comment.