Update EVO2 tests according to Hyena arch changes #798

Open · wants to merge 7 commits into base: main
Changes from 4 commits
2 changes: 1 addition & 1 deletion 3rdparty/NeMo
Submodule NeMo updated 196 files
@@ -68,14 +68,14 @@ def test_gpu_forward(self, operator: ParallelHyenaOperator):
         g = operator.num_groups
         dg = operator.group_dim

-        x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        v = torch.ones((batch_size, seq_len, g, dg), device=device)
+        x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        v = torch.ones((batch_size, (g * dg), seq_len), device=device)

         output = operator(x1, x2, v)
         assert output.shape[0] == batch_size
-        assert output.shape[1] == seq_len
-        assert output.shape[2] == operator.hidden_size
+        assert output.shape[1] == operator.hidden_size
+        assert output.shape[2] == seq_len


 class TestParallelShortHyenaOperator:
@@ -89,7 +89,6 @@ def operator(self, transformer_config: TransformerConfig, hyena_config: HyenaCon
             init_method="small_init",
             short_conv_class=ParallelCausalDepthwiseConv1d,
             use_fast_causal_conv=False,
-            is_mlp=False,
             local_init=False,
             use_conv_bias=False,
         )
@@ -109,14 +108,14 @@ def test_gpu_forward(self, operator: ParallelShortHyenaOperator):
         g = operator.num_groups
         dg = operator.group_dim

-        x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        v = torch.ones((batch_size, seq_len, g, dg), device=device)
+        x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        v = torch.ones((batch_size, (g * dg), seq_len), device=device)

         output = operator(x1, x2, v)
         assert output.shape[0] == batch_size
-        assert output.shape[1] == seq_len
-        assert output.shape[2] == operator.hidden_size
+        assert output.shape[1] == operator.hidden_size
+        assert output.shape[2] == seq_len


 class TestParallelShortHyenaOperatorWithConvBias:
@@ -130,7 +129,6 @@ def operator(self, transformer_config: TransformerConfig, hyena_config: HyenaCon
             init_method="small_init",
             short_conv_class=ParallelCausalDepthwiseConv1d,
             use_fast_causal_conv=False,
-            is_mlp=False,
             local_init=False,
             use_conv_bias=True,
         )
@@ -150,14 +148,14 @@ def test_gpu_forward(self, operator: ParallelShortHyenaOperator):
         g = operator.num_groups
         dg = operator.group_dim

-        x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        v = torch.ones((batch_size, seq_len, g, dg), device=device)
+        x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
Collaborator

Is there a test somewhere covering that this still works with tensor parallel? Moving the sequence to the last dimension could break tensor parallel, because that code has a lot of hardcoded assumptions about splitting on axis 1. Maybe run the brca notebook with TP=2 (using the experimental BF16 model weights if doing this on a non-FP8 node); if it still works, that would be good. Please post a manual verification to this effect.

Collaborator Author

I am not aware of any tests for TP, but all the tests in NeMo and BioNeMo are passing. The current CI failure is discussed in this thread and is unrelated to these changes.

I will run the notebook with TP=2 and report the results here.

Collaborator Author

I can now confirm that the notebook reproduces ToT results with TP=2 or CP=2 on two A6000 GPUs. However, there is a regression in ToT compared to the last time the notebook was executed, which is unrelated to the changes here (more info regarding the ToT regression).

+        x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        v = torch.ones((batch_size, (g * dg), seq_len), device=device)

         output = operator(x1, x2, v)
         assert output.shape[0] == batch_size
-        assert output.shape[1] == seq_len
-        assert output.shape[2] == operator.hidden_size
+        assert output.shape[1] == operator.hidden_size
+        assert output.shape[2] == seq_len


 class TestParallelCausalDepthwiseConv1d:
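
A minimal, self-contained sketch of the layout change these tests now encode, and of the axis-1 splitting concern raised in the review thread above. The shapes, the tp_size value, and the plain torch.chunk split are illustrative assumptions only; this is not NeMo/Megatron code and not part of the PR.

# Illustrative sketch (assumed shapes; not NeMo/Megatron internals).
# The updated tests feed the Hyena operators channels-first tensors of shape
# (batch, num_groups * group_dim, seq_len) and expect outputs of shape
# (batch, hidden_size, seq_len), replacing the previous
# (batch, seq_len, num_groups, group_dim) layout.
import torch

batch_size, seq_len, num_groups, group_dim = 2, 16, 4, 8
hidden_size = num_groups * group_dim  # 32

old_layout = torch.ones(batch_size, seq_len, num_groups, group_dim)  # (2, 16, 4, 8)
new_layout = torch.ones(batch_size, hidden_size, seq_len)            # (2, 32, 16)

# A split hardcoded to axis 1 (the kind of assumption the reviewer mentions)
# now partitions the hidden/channel axis instead of the sequence axis.
tp_size = 2
old_shards = torch.chunk(old_layout, tp_size, dim=1)  # each shard (2, 8, 4, 8): cuts seq_len
new_shards = torch.chunk(new_layout, tp_size, dim=1)  # each shard (2, 16, 16): cuts hidden

print([tuple(s.shape) for s in old_shards])  # [(2, 8, 4, 8), (2, 8, 4, 8)]
print([tuple(s.shape) for s in new_shards])  # [(2, 16, 16), (2, 16, 16)]

The sketch only visualizes the dimension-ordering change under discussion; the actual tensor-parallel verification is the TP=2 notebook run reported in the thread.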