deepseek r1 running #102
base: main
Changes from all commits
@@ -18,5 +18,5 @@
     "qk_nope_head_dim": 128,
     "qk_rope_head_dim": 64,
     "v_head_dim": 128,
-    "dtype": "fp8"
+    "dtype": "bfloat16"
 }
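For context, the config's "dtype" string feeds the dtype selection in Transformer.__init__ further down in this diff. A minimal sketch of that mapping (the helper function is purely illustrative, not part of the PR):

```python
import torch

def select_linear_dtype(dtype: str) -> torch.dtype:
    # Mirrors the Linear.dtype line visible in the Transformer.__init__ hunk below.
    return torch.float8_e4m3fn if dtype == "fp8" else torch.bfloat16

print(select_linear_dtype("bfloat16"))  # torch.bfloat16 (after this config change)
print(select_linear_dtype("fp8"))       # torch.float8_e4m3fn (before it)
```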
@@ -2,8 +2,8 @@
 from dataclasses import dataclass
 from typing import Literal
 
+import jax
 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 from torch import nn
 
@@ -382,7 +382,7 @@ def __init__(self, args: ModelArgs):
     def forward(
         self,
         x: torch.Tensor,
-        start_pos: int,
+        input_pos: torch.Tensor,
         freqs_cis: torch.Tensor,
         mask: torch.Tensor | None,
     ):
@@ -399,7 +399,6 @@ def forward(
             torch.Tensor: Output tensor with the same shape as the input.
         """
         bsz, seqlen, _ = x.size()
-        end_pos = start_pos + seqlen
         q = self.wq(x) if self.q_lora_rank == 0 else self.wq_b(self.q_norm(self.wq_a(x)))
         q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
         q_nope, q_pe = torch.split(
@@ -417,12 +416,9 @@
             )
             k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
-            self.k_cache[:bsz, start_pos:end_pos] = k
-            self.v_cache[:bsz, start_pos:end_pos] = v
-            scores = (
-                torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos])
-                * self.softmax_scale
-            )
+            # self.k_cache[:bsz, start_pos:end_pos] = k
+            # self.v_cache[:bsz, start_pos:end_pos] = v
+            scores = torch.einsum("bshd,bthd->bsht", q, k) * self.softmax_scale
         else:
             wkv_b = (
                 self.wkv_b.weight
@@ -431,19 +427,22 @@
             )
             wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
             q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim])
-            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
-            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
+            # self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
+            # self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
Comment on lines +430 to +431: We have changed how we do caching, as seen in the lines below. Do we have a reason to keep these comments?
+            kv_cache = self.kv_norm(kv)
+            pe_cache = k_pe.squeeze(2)
             scores = (
-                torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos])
-                + torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
+                torch.einsum("bshc,btc->bsht", q_nope, kv_cache)
+                + torch.einsum("bshr,btr->bsht", q_pe, pe_cache)
             ) * self.softmax_scale
         if mask is not None:
             scores += mask.unsqueeze(1)
         scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
         if attn_impl == "naive":
-            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
+            x = torch.einsum("bsht,bthd->bshd", scores, v)
         else:
-            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
+            kv_cache = self.kv_norm(kv)
+            x = torch.einsum("bsht,btc->bshc", scores, kv_cache)
             x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :])
         x = self.wo(x.flatten(2))
         return x
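As a sanity check on the cache-free score path above, a self-contained shape sketch (the sizes are made up; b = batch, s = query length, t = key length, h = heads, c = kv_lora_rank, r = rope head dim):

```python
import torch

b, s, t, h, c, r = 2, 4, 4, 3, 8, 4            # illustrative sizes only
q_nope = torch.randn(b, s, h, c)
q_pe = torch.randn(b, s, h, r)
kv_cache = torch.randn(b, t, c)                # kv_norm(kv) for the current sequence only
pe_cache = torch.randn(b, t, r)                # k_pe.squeeze(2) for the current sequence only

scores = (
    torch.einsum("bshc,btc->bsht", q_nope, kv_cache)
    + torch.einsum("bshr,btr->bsht", q_pe, pe_cache)
)
print(scores.shape)  # torch.Size([2, 4, 3, 4])
```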
@@ -544,6 +543,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
             else:
                 group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
             indices = group_scores.topk(self.topk_groups, dim=-1)[1]
+            print('i am here')
Test print. Please remove. If appropriate, we could use logging at the info level instead. I see a couple of other prints in model functions; I would consider applying the same criteria to those.
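A minimal sketch of the suggested info-level logging, assuming a module-level logger (the helper name and message are illustrative, not from the PR):

```python
import logging

logger = logging.getLogger(__name__)

def log_routing(topk_groups: int) -> None:
    # Replaces the bare print with a log record that can be filtered or silenced
    # through standard logging configuration.
    logger.info("gate routing: selected top-%d expert groups", topk_groups)

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    log_routing(4)
```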
             mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
             scores = (scores * mask.unsqueeze(-1)).flatten(1)
         indices = torch.topk(scores, self.topk, dim=-1)[1]
@@ -590,71 +590,79 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
-class MoE(nn.Module):
-    """
-    Mixture-of-Experts (MoE) module.
-
-    Attributes:
-        dim (int): Dimensionality of input features.
-        n_routed_experts (int): Total number of experts in the model.
-        n_local_experts (int): Number of experts handled locally in distributed systems.
-        n_activated_experts (int): Number of experts activated for each input.
-        gate (nn.Module): Gating mechanism to route inputs to experts.
-        experts (nn.ModuleList): List of expert modules.
-        shared_experts (nn.Module): Shared experts applied to all inputs.
-    """
-
-    def __init__(self, args: ModelArgs):
-        """
-        Initializes the MoE module.
-
-        Args:
-            args (ModelArgs): Model arguments containing MoE parameters.
-        """
+class ConditionalFeedForward(torch.nn.Module):
+    def __init__(self, config):
         super().__init__()
-        self.dim = args.dim
-        assert args.n_routed_experts % world_size == 0
-        self.n_routed_experts = args.n_routed_experts
-        self.n_local_experts = args.n_routed_experts // world_size
-        self.n_activated_experts = args.n_activated_experts
-        self.experts_start_idx = rank * self.n_local_experts
-        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
-        self.gate = Gate(args)
-        self.experts = nn.ModuleList(
-            [
-                Expert(args.dim, args.moe_inter_dim)
-                if self.experts_start_idx <= i < self.experts_end_idx
-                else None
-                for i in range(self.n_routed_experts)
-            ]
+        # TODO(How to enable quantization?)
+        self.w1 = nn.Parameter(
+            torch.empty(config.n_routed_experts, config.moe_inter_dim, config.dim)
         )
+        self.w2 = nn.Parameter(
+            torch.empty(config.n_routed_experts, config.dim, config.moe_inter_dim)
+        )
+        self.w3 = nn.Parameter(
+            torch.empty(config.n_routed_experts, config.moe_inter_dim, config.dim)
+        )
-        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
+        self.config = config
 
+    def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
+        return self.forward_for_long_seq_len(x, expert_indices)
 
+    def forward_for_long_seq_len(self, x, expert_indices):
+        seqlen = x.shape[0]
+        self.w1.shape[0]
 
+        # e = total num of exp = 8
+        # t = seqlen
+        # o = config.imtermediate size
+        # i = config.dim
+        with jax.named_scope("conditional_ff"):
+            x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1))
+            x3 = torch.einsum("ti, eoi-> teo", x, self.w3)
+            expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2)
+            # e = 8; need to reduce to 2
+            seq_indexes = torch.arange(seqlen, device=x.device).unsqueeze(1)
+            return expert_outs[seq_indexes, expert_indices]
 
 
+class MoE(torch.nn.Module):
+    def __init__(self, model_args) -> None:
+        super().__init__()
+        self.dim = model_args.dim
+        self.model_args = model_args
+        # assert args.n_routed_experts % world_size == 0
+        # self.n_routed_experts = args.n_routed_experts
+        # self.n_local_experts = args.n_routed_experts // world_size
+        # self.n_activated_experts = args.n_activated_experts
+        # self.experts_start_idx = rank * self.n_local_experts
+        # self.experts_end_idx = self.experts_start_idx + self.n_local_experts
+        self.gate = Gate(model_args)
+        # self.experts = nn.ModuleList(
+        #     [
+        #         Expert(args.dim, args.moe_inter_dim)
+        #         if self.experts_start_idx <= i < self.experts_end_idx
+        #         else None
+        #         for i in range(self.n_routed_experts)
+        #     ]
+        # )
Comment on lines +633 to +647: Do we have a reason to leave these as comments rather than removing them?
+        self.shared_experts = MLP(
+            model_args.dim, model_args.n_shared_experts * model_args.moe_inter_dim
+        )
+        self.cond_ffn = ConditionalFeedForward(model_args)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for the MoE module.
 
         Args:
             x (torch.Tensor): Input tensor.
 
         Returns:
             torch.Tensor: Output tensor after expert routing and computation.
         """
-        shape = x.size()
+        bsz, seq, hidden = x.shape
+        # [B, T, D], combine BT, for prefill B = 1, for decode, T = 1
         x = x.view(-1, self.dim)
+        # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
+        # x: [T, D]
+        self.gate(x)  # [T, E]
         weights, indices = self.gate(x)
-        y = torch.zeros_like(x)
-        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
-        for i in range(self.experts_start_idx, self.experts_end_idx):
-            if counts[i] == 0:
-                continue
-            expert = self.experts[i]
-            idx, top = torch.where(indices == i)
-            y[idx] += expert(x[idx]) * weights[idx, top, None]
-        z = self.shared_experts(x)
-        if world_size > 1:
-            dist.all_reduce(y)
-        return (y + z).view(shape)
+        expert_outs = self.cond_ffn(x, indices)
+        expert_outs = torch.einsum("tai,ta -> ti", expert_outs, weights)
+        # Changes back to [B, T, D]
+        expert_outs = expert_outs.reshape(bsz, seq, hidden)
+        return expert_outs
 
 
 class Block(nn.Module):
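To make the shapes in ConditionalFeedForward and the new MoE.forward concrete, here is a standalone sketch of the dense einsum routing with made-up sizes (every tensor name and dimension below is illustrative, not taken from the PR):

```python
import torch
import torch.nn.functional as F

t, e, d, o, a = 5, 8, 16, 32, 2                  # tokens, experts, model dim, expert hidden dim, activated experts
x = torch.randn(t, d)                            # [T, D] flattened tokens
w1 = torch.randn(e, o, d)                        # per-expert up projection
w2 = torch.randn(e, d, o)                        # per-expert down projection
w3 = torch.randn(e, o, d)                        # per-expert gate projection
expert_indices = torch.randint(0, e, (t, a))     # top-a expert ids per token
weights = torch.softmax(torch.randn(t, a), -1)   # routing weights for those experts

# Every token is pushed through every expert, then only the top-a outputs are kept.
x1 = F.silu(torch.einsum("ti,eoi->teo", x, w1))
x3 = torch.einsum("ti,eoi->teo", x, w3)
expert_outs = torch.einsum("teo,eio->tei", x1 * x3, w2)             # [T, E, D]
picked = expert_outs[torch.arange(t).unsqueeze(1), expert_indices]  # [T, A, D]
combined = torch.einsum("tai,ta->ti", picked, weights)              # [T, D]
print(picked.shape, combined.shape)  # torch.Size([5, 2, 16]) torch.Size([5, 16])
```

Note that this computes every expert densely and then discards most of the outputs, which is what the "# e = 8; need to reduce to 2" comment in the diff refers to.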
@@ -687,7 +695,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
     def forward(
         self,
         x: torch.Tensor,
-        start_pos: int,
+        input_pos: torch.Tensor,
         freqs_cis: torch.Tensor,
         mask: torch.Tensor | None,
     ) -> torch.Tensor:
@@ -703,7 +711,7 @@ def forward(
         Returns:
             torch.Tensor: Output tensor after block computation.
         """
-        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
+        x = x + self.attn(self.attn_norm(x), input_pos, freqs_cis, mask)
         x = x + self.ffn(self.ffn_norm(x))
         return x
 
@@ -728,9 +736,6 @@ def __init__(self, args: ModelArgs):
         Args:
             args (ModelArgs): Model arguments containing transformer parameters.
         """
-        global world_size, rank
-        world_size = dist.get_world_size() if dist.is_initialized() else 1
-        rank = dist.get_rank() if dist.is_initialized() else 0
         Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
         super().__init__()
         self.max_seq_len = args.max_seq_len
@@ -742,10 +747,10 @@ def __init__(self, args: ModelArgs):
         self.head = ColumnParallelLinear(
             args.dim, args.vocab_size, dtype=torch.get_default_dtype()
         )
-        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
+        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=True)
 
     @torch.inference_mode()
-    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
+    def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
         """
         Forward pass for the Transformer model.
 
@@ -756,18 +761,14 @@ def forward(self, tokens: torch.Tensor, start_pos: int = 0):
         Returns:
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
         """
-        seqlen = tokens.size(1)
+        tokens.size(1)
         h = self.embed(tokens)
-        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+        freqs_cis = self.freqs_cis[input_pos]
         mask = None
-        if seqlen > 1:
-            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
+        # if seqlen > 1:
+        #     mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
Comment on lines +768 to +769: Do we have a reason to leave these as comments rather than removing them?
         for layer in self.layers:
-            h = layer(h, start_pos, freqs_cis, mask)
+            h = layer(h, input_pos, freqs_cis, mask)
         h = self.norm(h)[:, -1]
         logits = self.head(h)
-        if world_size > 1:
-            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
-            dist.all_gather(all_logits, logits)
-            logits = torch.cat(all_logits, dim=-1)
         return logits
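For the start_pos: int to input_pos: torch.Tensor change above, a small sketch of how tensor positions index the freqs_cis buffer for prefill versus single-token decode (values and shapes are illustrative only):

```python
import torch

max_seq_len, dim = 16, 8
freqs_cis = torch.randn(max_seq_len, dim)   # stand-in for the precomputed table

# Prefill: a 5-token prompt starting at position 0.
prefill_pos = torch.arange(0, 5)
print(freqs_cis[prefill_pos].shape)         # torch.Size([5, 8])

# Decode: a single new token at position 5.
decode_pos = torch.tensor([5])
print(freqs_cis[decode_pos].shape)          # torch.Size([1, 8])
```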
See comment below.