From 776eb01936b8674fdab5ecfc7b9082b416b90d67 Mon Sep 17 00:00:00 2001
From: Jerry Jiang <83740943+AGENDD@users.noreply.github.com>
Date: Fri, 22 Nov 2024 19:35:53 +0800
Subject: [PATCH] Update model.py

---
 rwkv_pip_package/src/rwkv/model.py | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/rwkv_pip_package/src/rwkv/model.py b/rwkv_pip_package/src/rwkv/model.py
index 7f531b1..80ff76f 100644
--- a/rwkv_pip_package/src/rwkv/model.py
+++ b/rwkv_pip_package/src/rwkv/model.py
@@ -317,11 +317,14 @@ def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit =
 
             if not ALREADY_CONVERTED:
                 try: # precompute embedding
+                    w['emb.weight_non_norm'] = w['emb.weight']
                     w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
+
                 except:
+                    w['emb.weight_non_norm'] = w['emb.weight'].float()
                     w['emb.weight'] = F.layer_norm(w['emb.weight'].float(), (args.n_embd,), weight=w['blocks.0.ln0.weight'].float(), bias=w['blocks.0.ln0.bias'].float())
-                del w['blocks.0.ln0.weight']
-                del w['blocks.0.ln0.bias']
+                # del w['blocks.0.ln0.weight']
+                # del w['blocks.0.ln0.bias']
 
             print_need_newline = False
 
@@ -1004,7 +1007,13 @@ def cuda_att_seq_v6_0(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, x_maa, w_maa, k_ma
 
     ########################################################################################################
 
-    def forward(self, tokens, state, full_output=False):
+    def embed(self, tokens):
+        w = self.w
+        seq_mode = len(tokens) > 1
+
+        return w['emb.weight_non_norm'][tokens if seq_mode else tokens[0]]
+
+    def forward(self, tokens, state, embed=None, full_output=False):
         with torch.no_grad():
             w = self.w
             args = self.args
@@ -1033,10 +1042,16 @@ def forward(self, tokens, state, full_output=False):
                         else:
                             state[i*3+1] = torch.zeros((args.n_head, args.n_att//args.n_head, args.n_att//args.n_head), dtype=torch.float, requires_grad=False, device=dev).contiguous()
                         state[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
+
+
 
-            seq_mode = len(tokens) > 1
-
-            x = w['emb.weight'][tokens if seq_mode else tokens[0]]
+            if(embed != None):
+                seq_mode = len(embed) > 1
+                x = embed
+                x = F.layer_norm(x, (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
+            else:
+                seq_mode = len(tokens) > 1
+                x = w['emb.weight'][tokens if seq_mode else tokens[0]]
 
             for i in range(args.n_layer):
                 bbb = f'blocks.{i}.'
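
Usage note (editorial, not part of the patch): the change keeps an un-normalized copy of the embedding matrix as emb.weight_non_norm, exposes it through the new embed() method, and lets forward() accept pre-computed embeddings via the new embed= argument, applying blocks.0.ln0 to them internally (which is why the two del statements are commented out). Below is a minimal sketch of how this hook might be driven, assuming a loaded rwkv.model.RWKV instance with a placeholder model path, strategy, and token ids; the perturbation line is purely illustrative, and device/dtype handling may need adjusting to the chosen strategy.

    import torch
    from rwkv.model import RWKV

    model = RWKV(model='path/to/rwkv-model', strategy='cpu fp32')  # placeholder path and strategy
    tokens = [100, 200, 300]                                       # placeholder token ids

    emb = model.embed(tokens)                  # un-normalized rows of emb.weight_non_norm
    emb = emb + 0.01 * torch.randn_like(emb)   # illustrative: perturb or swap in custom embeddings here

    out, state = model.forward(tokens, None, embed=emb)  # forward() applies ln0 to `embed` itself
    out, state = model.forward(tokens, state)            # the original token-id path is unchanged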