File tree Expand file tree Collapse file tree 3 files changed +9
-0
lines changed
Expand file tree Collapse file tree 3 files changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -395,11 +395,17 @@ def args_sanity_check():
395395 gpc .config .parallel ["tensor" ] = dict (size = gpc .config .parallel ["tensor" ], mode = TensorParallelMode .mtp .name )
396396 if gpc .config .parallel ["tensor" ].get ("mode" , None ) is None :
397397 gpc .config .parallel ["tensor" ]["mode" ] = TensorParallelMode .mtp .name
398+ assert (
399+ gpc .config .VOCAB_SIZE % gpc .config .parallel .tensor .size == 0
400+ ), "VOCAB_SIZE must be integer multiple of tensor parallel size"
398401 if gpc .config .parallel ["tensor" ]["mode" ] == TensorParallelMode .isp .name :
399402 assert not gpc .config .parallel .zero1 .fsdp , "FSDP does not support isp"
400403 assert (
401404 torch .__version__ >= "2.1.0"
402405 ), f"requires torch>=2.1.0 when using isp but current version is { torch .__version__ } "
406+ assert (
407+ gpc .config .VOCAB_SIZE % gpc .config .parallel .weight .size == 0
408+ ), "VOCAB_SIZE must be integer multiple of wp size"
403409
404410 assert gpc .config .parallel ["tensor" ].get ("mode" , None ) in [
405411 TensorParallelMode .mtp .name ,
Original file line number Diff line number Diff line change @@ -602,8 +602,10 @@ def __init__(
602602 split_features = out_features if split_mode == "column" else in_features
603603 multiple = split_features // multiple_of
604604 # We want to split @multiple across world_size, but it could be an uneven split
605+ # uneven split is forbidden
605606 div = multiple // world_size
606607 mod = multiple % world_size
608+ assert mod == 0 , "linear module uneven split is forbidden"
607609 # The first @mod ranks get @div + 1 copies, the rest get @div copies
608610 local_multiple = div + int (rank < mod )
609611
Original file line number Diff line number Diff line change 3131TOTAL_STEPS = 1
3232config = Config (
3333 dict (
34+ VOCAB_SIZE = 92544 ,
3435 parallel = dict (
3536 zero1 = dict (size = - 1 ),
3637 tensor = dict (size = 1 , mode = "mtp" ),
You can’t perform that action at this time.
0 commit comments