 import gc
 import logging
+import os
 import time
 from functools import partial
 from typing import Dict, List, Optional, Union
...
 import torch.distributed as dist
 from torch.utils.data import DataLoader

+from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.checkpoint.checkpoint_manager import CheckpointManager
 from internlm.core.context import global_context as gpc
 from internlm.core.context.process_group_initializer import ParallelMode
...
 )
 from internlm.utils.common import (
     BatchSkipper,
-    check_cuda_env,
     enable_pytorch_expandable_segments,
     get_current_device,
     get_megatron_flops,
...
 # global llm logger
 logger = logging.getLogger(__file__)
+internlm_accelerator = get_accelerator()
+
+
+def check_cuda_env():
+    # Validate CUDA/NCCL environment settings when training on the GPU backend.
+    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
+        wp_fwd_per = gpc.config.parallel.weight.get("forward_overlap_per", "layer")
+        ewp_fwd_per = gpc.config.parallel.expert_weight.get("forward_overlap_per", "layer")
+        wp_size = gpc.config.parallel.weight.get("size", 1)
+        ewp_size = gpc.config.parallel.expert_weight.get("size", 1)
+        # CUDA_DEVICE_MAX_CONNECTIONS=1 is required only when neither weight
+        # parallel nor expert weight parallel overlaps forward communication
+        # at "layer" granularity.
+        open_max_conns = (wp_size == 1 or wp_fwd_per != "layer") and (ewp_size == 1 or ewp_fwd_per != "layer")
+        if open_max_conns:
+            max_connections = os.getenv("CUDA_DEVICE_MAX_CONNECTIONS")
+            assert (
+                max_connections is not None
+            ), "Env var CUDA_DEVICE_MAX_CONNECTIONS has not been set, please set it to 1!"
+            assert (
+                max_connections == "1"
+            ), "Env var CUDA_DEVICE_MAX_CONNECTIONS is set to {}, it should be set to 1!".format(max_connections)
+
+        # TORCH_NCCL_AVOID_RECORD_STREAMS=1 is required on GPU in all cases.
+        avoid_record_streams = os.getenv("TORCH_NCCL_AVOID_RECORD_STREAMS")
+        assert (
+            avoid_record_streams is not None
+        ), "Env var TORCH_NCCL_AVOID_RECORD_STREAMS has not been set, please set it to 1!"
+        assert (
+            avoid_record_streams == "1"
+        ), "Env var TORCH_NCCL_AVOID_RECORD_STREAMS is set to {}, it should be set to 1!".format(avoid_record_streams)


 class TrainerBuilder(Trainer):
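
Note on the `open_max_conns` gate above: the strict `CUDA_DEVICE_MAX_CONNECTIONS == "1"` assertion fires only when neither weight parallelism nor expert weight parallelism overlaps its forward communication at `"layer"` granularity; layer-wise overlap presumably needs more than one CUDA connection, so the check is skipped in that case. A standalone sketch of the predicate (hypothetical helper name and argument values, no `gpc` config required):

```python
# Hypothetical distillation of the open_max_conns predicate in check_cuda_env().
def requires_single_connection(wp_size, wp_fwd_per, ewp_size, ewp_fwd_per):
    return (wp_size == 1 or wp_fwd_per != "layer") and (ewp_size == 1 or ewp_fwd_per != "layer")

assert requires_single_connection(1, "layer", 1, "layer")      # no weight parallelism at all
assert not requires_single_connection(4, "layer", 1, "layer")  # layer-wise overlap: check skipped
assert requires_single_connection(4, "module", 1, "layer")     # overlap at non-"layer" granularity
```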
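`check_cuda_env()` only asserts that the two variables are already set; it does not set them. Since CUDA reads `CUDA_DEVICE_MAX_CONNECTIONS` at context creation and PyTorch reads `TORCH_NCCL_AVOID_RECORD_STREAMS` when the NCCL process group is built, they should be in the environment before the training process initializes either. A minimal sketch of launch-side setup (the placement at the top of the entry script is an assumption, not part of this commit):

```python
import os

# Values follow the assertions in check_cuda_env(); set them before torch
# initializes CUDA and the NCCL process group (assumed placement).
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")
os.environ.setdefault("TORCH_NCCL_AVOID_RECORD_STREAMS", "1")
```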