
[Not for landing] piggy back on titan for scale init test #841


Draft
wants to merge 9 commits into base: gh/fduwjj/1/base
2 changes: 1 addition & 1 deletion torchtitan/models/llama/train_configs/llama3_405b.toml
@@ -8,7 +8,7 @@ description = "Llama 3 405B training"
[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
-profile_freq = 100
+profile_freq = 5

[metrics]
log_freq = 10
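
Dropping profile_freq from 100 to 5 makes trace collection fire every few steps, which matters here because each timed init iteration below only runs for a handful of steps. For orientation, a minimal sketch of step-gated profiling, assuming the profiler is triggered when the step counter hits a multiple of profile_freq; the helper name and signature are illustrative, not torchtitan's actual maybe_enable_profiling:

import contextlib
import torch

@contextlib.contextmanager
def step_gated_profiler(step: int, profile_freq: int, trace_dir: str = "profile_trace"):
    # Illustrative gate: profile only on steps that are multiples of profile_freq.
    if profile_freq > 0 and step % profile_freq == 0:
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
        ) as prof:
            yield prof
    else:
        yield None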
37 changes: 35 additions & 2 deletions torchtitan/train.py
@@ -447,7 +447,40 @@ def main(job_config: JobConfig):

if __name__ == "__main__":
    init_logger()
    warmup = False
    from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT, PrefixStore

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    global_rank = int(os.environ["RANK"])
    index = 0
    # Rendezvous once and keep a single long-lived TCPStore; every timed
    # init below reuses it through a fresh PrefixStore namespace.
    rendezvous_iterator = torch.distributed.rendezvous(
        "env://", global_rank, world_size, timeout=_DEFAULT_PG_NCCL_TIMEOUT
    )
    tcp_store, rank, world_size = next(rendezvous_iterator)
    tcp_store.set_timeout(_DEFAULT_PG_NCCL_TIMEOUT)
    config = JobConfig()
    config.parse_args()
    # The first full training run is just for warm-up.
    main(config)
    torch.distributed.destroy_process_group()
    for root_size in [128]:
        os.environ["TORCH_NCCL_RANKS_PER_ROOT"] = str(root_size)
        iter_size = 10
        delta = 0.0
        for i in range(iter_size):
            start = time.perf_counter()
            # A fresh prefix per iteration keeps keys written by the
            # previous process group from colliding in the shared store.
            store = PrefixStore(f"default_pg_{index}", tcp_store)
            index += 1
            torch.cuda.set_device(local_rank)
            torch.distributed.init_process_group(
                store=store, backend="nccl", world_size=world_size, rank=global_rank
            )
            with maybe_enable_profiling(config, global_step=i) as torch_profiler:
                torch.distributed.barrier()
            end = time.perf_counter()
            torch.distributed.destroy_process_group()
            delta += end - start
            print(f"Time to init process group: {end - start:.6f} seconds for {root_size} ranks per root")
        # Only report the average once the first (warm-up) sweep is done.
        if warmup:
            print(f"Average time to init process group: {delta / float(iter_size):.6f} seconds for {root_size} ranks per root")
        warmup = True
    # main(config)
    # torch.distributed.destroy_process_group()
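
Two notes on the design. TORCH_NCCL_RANKS_PER_ROOT appears to control how many ranks share a single root during NCCL's scalable communicator bootstrap, which is the knob the sweep varies. And each timed iteration wraps the one long-lived TCPStore in a fresh PrefixStore (default_pg_0, default_pg_1, ...): reusing the same prefix across init/destroy cycles would let keys written by a previous process group collide with the next one's. A minimal, hypothetical sketch of that namespacing in isolation (host, port, and key names are made up):

from torch.distributed import PrefixStore, TCPStore

# Single-process example: one long-lived TCPStore, as the rendezvous above
# produces, then one PrefixStore per logical process group.
tcp_store = TCPStore("127.0.0.1", 29500, world_size=1, is_master=True)

# Each PrefixStore scopes its keys under its own namespace, so two
# successive process groups never observe each other's entries.
store_a = PrefixStore("default_pg_0", tcp_store)
store_b = PrefixStore("default_pg_1", tcp_store)
store_a.set("unique_id", "aaaa")
store_b.set("unique_id", "bbbb")
assert store_a.get("unique_id") == b"aaaa"
assert store_b.get("unique_id") == b"bbbb"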