Skip to content

Commit 6e7163c

Browse files
fix(linear.py): linear module uneven split is forbidden (#374)
1 parent aee457c commit 6e7163c

File tree

3 files changed

+9
-0
lines changed

3 files changed

+9
-0
lines changed

internlm/initialize/launch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,17 @@ def args_sanity_check():
395395
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name)
396396
if gpc.config.parallel["tensor"].get("mode", None) is None:
397397
gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name
398+
assert (
399+
gpc.config.VOCAB_SIZE % gpc.config.parallel.tensor.size == 0
400+
), "VOCAB_SIZE must be integer multiple of tensor parallel size"
398401
if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name:
399402
assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp"
400403
assert (
401404
torch.__version__ >= "2.1.0"
402405
), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}"
406+
assert (
407+
gpc.config.VOCAB_SIZE % gpc.config.parallel.weight.size == 0
408+
), "VOCAB_SIZE must be integer multiple of wp size"
403409

404410
assert gpc.config.parallel["tensor"].get("mode", None) in [
405411
TensorParallelMode.mtp.name,

internlm/model/modules/linear.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,8 +602,10 @@ def __init__(
602602
split_features = out_features if split_mode == "column" else in_features
603603
multiple = split_features // multiple_of
604604
# We want to split @multiple across world_size, but it could be an uneven split
605+
# An uneven split is forbidden: @multiple must be evenly divisible by world_size (asserted below)
605606
div = multiple // world_size
606607
mod = multiple % world_size
608+
assert mod == 0, "linear module uneven split is forbidden"
607609
# The first @mod ranks get @div + 1 copies, the rest get @div copies; since mod == 0 is asserted above, every rank gets exactly @div copies
608610
local_multiple = div + int(rank < mod)
609611

tests/test_training/test_forward_output_no_fa.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
TOTAL_STEPS = 1
3232
config = Config(
3333
dict(
34+
VOCAB_SIZE=92544,
3435
parallel=dict(
3536
zero1=dict(size=-1),
3637
tensor=dict(size=1, mode="mtp"),

0 commit comments

Comments
 (0)