
deepseek r1 running #102

Open
wants to merge 6 commits into main

1 change: 1 addition & 0 deletions pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
"tensorboard-plugin-profile==2.18.0",
"tf_keras==2.18.0",
"protobuf==4.25.5",
"fire",
]

[project.optional-dependencies]
@@ -18,5 +18,5 @@
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"dtype": "fp8"
"dtype": "bfloat16"
}
181 changes: 91 additions & 90 deletions torchprime/experimental/torchax_models/deepseek_v3/model.py
@@ -2,8 +2,8 @@
from dataclasses import dataclass
from typing import Literal

import jax
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn

@@ -382,7 +382,7 @@ def __init__(self, args: ModelArgs):
def forward(
self,
x: torch.Tensor,
start_pos: int,
input_pos: torch.Tensor,
freqs_cis: torch.Tensor,
mask: torch.Tensor | None,
):
@@ -399,7 +399,6 @@ def forward(
torch.Tensor: Output tensor with the same shape as the input.
"""
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
q = self.wq(x) if self.q_lora_rank == 0 else self.wq_b(self.q_norm(self.wq_a(x)))
q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
q_nope, q_pe = torch.split(
@@ -417,12 +416,9 @@ def forward(
)
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
self.k_cache[:bsz, start_pos:end_pos] = k
self.v_cache[:bsz, start_pos:end_pos] = v
scores = (
torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos])
* self.softmax_scale
)
# self.k_cache[:bsz, start_pos:end_pos] = k
# self.v_cache[:bsz, start_pos:end_pos] = v
Comment on lines +419 to +420
Collaborator: See comment below.

scores = torch.einsum("bshd,bthd->bsht", q, k) * self.softmax_scale
else:
wkv_b = (
self.wkv_b.weight
@@ -431,19 +427,22 @@ )
)
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim])
self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
# self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
# self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
Comment on lines +430 to +431
Collaborator: We have changed how we do caching, as seen in the lines below. Do we have a reason to keep these comments?

kv_cache = self.kv_norm(kv)
pe_cache = k_pe.squeeze(2)
scores = (
torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos])
+ torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
torch.einsum("bshc,btc->bsht", q_nope, kv_cache)
+ torch.einsum("bshr,btr->bsht", q_pe, pe_cache)
) * self.softmax_scale
if mask is not None:
scores += mask.unsqueeze(1)
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
if attn_impl == "naive":
x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
x = torch.einsum("bsht,bthd->bshd", scores, v)
else:
x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
kv_cache = self.kv_norm(kv)
x = torch.einsum("bsht,btc->bshc", scores, kv_cache)
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :])
x = self.wo(x.flatten(2))
return x
@@ -544,6 +543,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
else:
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
print('i am here')
Collaborator: Test print. Please remove. If appropriate, we could use some logging at the info level.

I see there are a couple of other prints in model functions; I would consider applying the same criteria to those.
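For illustration only, a minimal sketch of the logging approach suggested above (the logger name and message text are placeholders, not part of this PR):

import logging

logger = logging.getLogger(__name__)  # module-level logger, hypothetical name

# Replaces print('i am here'); info-level messages can be enabled or silenced
# through the standard logging configuration instead of editing the model code.
logger.info("Gate.forward: using the grouped top-k routing branch")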

mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
scores = (scores * mask.unsqueeze(-1)).flatten(1)
indices = torch.topk(scores, self.topk, dim=-1)[1]
@@ -590,71 +590,79 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))


class MoE(nn.Module):
"""
Mixture-of-Experts (MoE) module.

Attributes:
dim (int): Dimensionality of input features.
n_routed_experts (int): Total number of experts in the model.
n_local_experts (int): Number of experts handled locally in distributed systems.
n_activated_experts (int): Number of experts activated for each input.
gate (nn.Module): Gating mechanism to route inputs to experts.
experts (nn.ModuleList): List of expert modules.
shared_experts (nn.Module): Shared experts applied to all inputs.
"""

def __init__(self, args: ModelArgs):
"""
Initializes the MoE module.

Args:
args (ModelArgs): Model arguments containing MoE parameters.
"""
class ConditionalFeedForward(torch.nn.Module):
def __init__(self, config):
super().__init__()
self.dim = args.dim
assert args.n_routed_experts % world_size == 0
self.n_routed_experts = args.n_routed_experts
self.n_local_experts = args.n_routed_experts // world_size
self.n_activated_experts = args.n_activated_experts
self.experts_start_idx = rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(args)
self.experts = nn.ModuleList(
[
Expert(args.dim, args.moe_inter_dim)
if self.experts_start_idx <= i < self.experts_end_idx
else None
for i in range(self.n_routed_experts)
]
# TODO(How to enable quantization?)
self.w1 = nn.Parameter(
torch.empty(config.n_routed_experts, config.moe_inter_dim, config.dim)
)
self.w2 = nn.Parameter(
torch.empty(config.n_routed_experts, config.dim, config.moe_inter_dim)
)
self.w3 = nn.Parameter(
torch.empty(config.n_routed_experts, config.moe_inter_dim, config.dim)
)
self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
self.config = config

def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
return self.forward_for_long_seq_len(x, expert_indices)

def forward_for_long_seq_len(self, x, expert_indices):
seqlen = x.shape[0]
self.w1.shape[0]

# e = total num of exp = 8
# t = seqlen
# o = config.intermediate size
# i = config.dim
with jax.named_scope("conditional_ff"):
x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1))
x3 = torch.einsum("ti, eoi-> teo", x, self.w3)
expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2)
# e = 8; need to reduce to 2
seq_indexes = torch.arange(seqlen, device=x.device).unsqueeze(1)
return expert_outs[seq_indexes, expert_indices]


class MoE(torch.nn.Module):
def __init__(self, model_args) -> None:
super().__init__()
self.dim = model_args.dim
self.model_args = model_args
# assert args.n_routed_experts % world_size == 0
# self.n_routed_experts = args.n_routed_experts
# self.n_local_experts = args.n_routed_experts // world_size
# self.n_activated_experts = args.n_activated_experts
# self.experts_start_idx = rank * self.n_local_experts
# self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(model_args)
# self.experts = nn.ModuleList(
# [
# Expert(args.dim, args.moe_inter_dim)
# if self.experts_start_idx <= i < self.experts_end_idx
# else None
# for i in range(self.n_routed_experts)
# ]
# )
Comment on lines +633 to +647
Collaborator: Do we have reason to leave these as comments rather than removing them?

self.shared_experts = MLP(
model_args.dim, model_args.n_shared_experts * model_args.moe_inter_dim
)
self.cond_ffn = ConditionalFeedForward(model_args)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MoE module.

Args:
x (torch.Tensor): Input tensor.

Returns:
torch.Tensor: Output tensor after expert routing and computation.
"""
shape = x.size()
bsz, seq, hidden = x.shape
# [B, T, D], combine BT, for prefill B = 1, for decode, T = 1
x = x.view(-1, self.dim)
# T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
# x: [T, D]
self.gate(x) # [T, E]
weights, indices = self.gate(x)
y = torch.zeros_like(x)
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
for i in range(self.experts_start_idx, self.experts_end_idx):
if counts[i] == 0:
continue
expert = self.experts[i]
idx, top = torch.where(indices == i)
y[idx] += expert(x[idx]) * weights[idx, top, None]
z = self.shared_experts(x)
if world_size > 1:
dist.all_reduce(y)
return (y + z).view(shape)
expert_outs = self.cond_ffn(x, indices)
expert_outs = torch.einsum("tai,ta -> ti", expert_outs, weights)
# Changes back to [B, T, D]
expert_outs = expert_outs.reshape(bsz, seq, hidden)
return expert_outs


class Block(nn.Module):
@@ -687,7 +695,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
def forward(
self,
x: torch.Tensor,
start_pos: int,
input_pos: torch.Tensor,
freqs_cis: torch.Tensor,
mask: torch.Tensor | None,
) -> torch.Tensor:
@@ -703,7 +711,7 @@ def forward(
Returns:
torch.Tensor: Output tensor after block computation.
"""
x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
x = x + self.attn(self.attn_norm(x), input_pos, freqs_cis, mask)
x = x + self.ffn(self.ffn_norm(x))
return x

Expand All @@ -728,9 +736,6 @@ def __init__(self, args: ModelArgs):
Args:
args (ModelArgs): Model arguments containing transformer parameters.
"""
global world_size, rank
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank = dist.get_rank() if dist.is_initialized() else 0
Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
super().__init__()
self.max_seq_len = args.max_seq_len
@@ -742,10 +747,10 @@ def __init__(self, args: ModelArgs):
self.head = ColumnParallelLinear(
args.dim, args.vocab_size, dtype=torch.get_default_dtype()
)
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=True)

@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int = 0):
def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
"""
Forward pass for the Transformer model.

@@ -756,18 +761,14 @@ def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
Returns:
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
"""
seqlen = tokens.size(1)
tokens.size(1)
h = self.embed(tokens)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
freqs_cis = self.freqs_cis[input_pos]
mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
# if seqlen > 1:
# mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
Comment on lines +768 to +769
Collaborator: Do we have reason to leave these as comments rather than removing them?

for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = layer(h, input_pos, freqs_cis, mask)
h = self.norm(h)[:, -1]
logits = self.head(h)
if world_size > 1:
all_logits = [torch.empty_like(logits) for _ in range(world_size)]
dist.all_gather(all_logits, logits)
logits = torch.cat(all_logits, dim=-1)
return logits