From 8df68a97a44c435397052922888243f064ad1ea4 Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Sat, 18 May 2024 00:13:58 +0000
Subject: [PATCH 1/4] Add debug prints

---
 llama/model.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/llama/model.py b/llama/model.py
index f6e6d4f2a..de9795c89 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -122,7 +122,8 @@ def __init__(self,
                  args: ModelArgs,
                  world_size: Optional[int] = None,
                  rank: Optional[int] = None,
-                 groups: Optional[List] = None):
+                 groups: Optional[List] = None,
+                 layer_id: int = None):
         super().__init__()
 
         self.n_local_heads = args.n_heads
@@ -181,6 +182,8 @@ def forward(
         mask: Optional[torch.Tensor],
         input_indexes: torch.Tensor,
     ):
+        if self.layer_id == 0:
+            print(1, id(self.cache_k))
         bsz, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
@@ -192,9 +195,12 @@ def forward(
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
         self.cache_k = self.cache_k.index_copy(1, input_indexes, xk)
         self.cache_v = self.cache_v.index_copy(1, input_indexes, xv)
-
+        if self.layer_id == 0:
+            print(2, id(self.cache_k))
         keys = self.cache_k[:, :]
         values = self.cache_v[:, :]
+        if self.layer_id == 0:
+            print(3, id(self.cache_k))
 
         # repeat k/v heads if n_kv_heads < n_heads
         keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
@@ -202,6 +208,8 @@ def forward(
 
         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         keys = keys.transpose(1, 2)
+        if self.layer_id == 0:
+            print(4, id(self.cache_k))
         values = values.transpose(1, 2)
         #scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
         scores = torch.einsum('ijkl,ijml->ijkm', xq, keys)
@@ -287,6 +295,7 @@ def __init__(self,
             world_size=world_size,
             rank=rank,
             groups=groups,
+            layer_id=layer_id,
         )
         self.feed_forward = FeedForward(
             dim=args.dim,

From e07f0bf1ccbbe2f3b9061e20e2882c581c5db27c Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Mon, 20 May 2024 20:18:17 +0000
Subject: [PATCH 2/4] Updated index_copy to in-place index_copy_

---
 llama/model.py | 16 +++-------------
 1 file changed, 3 insertions(+), 13 deletions(-)

diff --git a/llama/model.py b/llama/model.py
index de9795c89..b0863028e 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -122,8 +122,7 @@ def __init__(self,
                  args: ModelArgs,
                  world_size: Optional[int] = None,
                  rank: Optional[int] = None,
-                 groups: Optional[List] = None,
-                 layer_id: int = None):
+                 groups: Optional[List] = None):
         super().__init__()
 
         self.n_local_heads = args.n_heads
@@ -182,8 +181,6 @@ def forward(
         mask: Optional[torch.Tensor],
         input_indexes: torch.Tensor,
     ):
-        if self.layer_id == 0:
-            print(1, id(self.cache_k))
         bsz, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
@@ -193,14 +190,10 @@ def forward(
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
-        self.cache_k = self.cache_k.index_copy(1, input_indexes, xk)
-        self.cache_v = self.cache_v.index_copy(1, input_indexes, xv)
-        if self.layer_id == 0:
-            print(2, id(self.cache_k))
+        self.cache_k = self.cache_k.index_copy_(1, input_indexes, xk)
+        self.cache_v = self.cache_v.index_copy_(1, input_indexes, xv)
         keys = self.cache_k[:, :]
         values = self.cache_v[:, :]
-        if self.layer_id == 0:
-            print(3, id(self.cache_k))
 
         # repeat k/v heads if n_kv_heads < n_heads
         keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
@@ -208,8 +201,6 @@ def forward(
 
         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         keys = keys.transpose(1, 2)
-        if self.layer_id == 0:
-            print(4, id(self.cache_k))
         values = values.transpose(1, 2)
         #scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
         scores = torch.einsum('ijkl,ijml->ijkm', xq, keys)
@@ -295,7 +286,6 @@ def __init__(self,
             world_size=world_size,
             rank=rank,
             groups=groups,
-            layer_id=layer_id,
         )
         self.feed_forward = FeedForward(
             dim=args.dim,

From 1a0a426c885e24ab88dc7da08303d2934b7d9e68 Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Mon, 20 May 2024 21:39:45 +0000
Subject: [PATCH 3/4] Remove assignment

---
 llama/model.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/llama/model.py b/llama/model.py
index b0863028e..67bbefe4c 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -190,8 +190,9 @@ def forward(
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
-        self.cache_k = self.cache_k.index_copy_(1, input_indexes, xk)
-        self.cache_v = self.cache_v.index_copy_(1, input_indexes, xv)
+        self.cache_k.index_copy_(1, input_indexes, xk)
+        self.cache_v.index_copy_(1, input_indexes, xv)
+        
         keys = self.cache_k[:, :]
         values = self.cache_v[:, :]

From 2b1217ee58d95a3afb5ceb4d9332b5d54cd6a729 Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Mon, 20 May 2024 21:45:03 +0000
Subject: [PATCH 4/4] Remove extra space

---
 llama/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama/model.py b/llama/model.py
index 67bbefe4c..44962740b 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -192,7 +192,7 @@ def forward(
         self.cache_k.index_copy_(1, input_indexes, xk)
         self.cache_v.index_copy_(1, input_indexes, xv)
-        
+
         keys = self.cache_k[:, :]
         values = self.cache_v[:, :]
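
Note on PATCH 2/4 and 3/4: torch.Tensor.index_copy is out of place (it returns a new tensor and leaves the original untouched), while index_copy_ writes into the tensor's own storage, so the KV-cache buffers keep the same object identity across forward calls and the reassignment to self.cache_k / self.cache_v becomes unnecessary. That identity is what the id(self.cache_k) debug prints in PATCH 1/4 were tracking. A minimal sketch of the difference; the names and shapes (cache, idx, new_kv) are illustrative and not taken from llama/model.py:

    import torch

    cache = torch.zeros(2, 8, 4)    # e.g. (bsz, max_seq_len, head_dim)
    idx = torch.tensor([3, 4])      # positions to overwrite along dim 1
    new_kv = torch.ones(2, 2, 4)    # matches cache except along dim 1

    out = cache.index_copy(1, idx, new_kv)   # out of place: allocates and returns a new tensor
    print(id(out) == id(cache))              # False; cache itself is still all zeros

    cache.index_copy_(1, idx, new_kv)        # in place: writes into cache's own storage
    print(cache[0, 3, 0].item())             # 1.0; the buffer now holds the new values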