How can I train the script on a 48 GB GPU? Or: how do I enable mixed precision (autocast) so training fits on a 48 GB GPU? #57

Open
@WangzcBruce

Description

I’m trying to train using a single 48 GB GPU but still run into memory issues. My launch command is:
```sh
PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32 HF_ENDPOINT='https://hf-mirror.com' \
python train.py \
    --base configs/example/nusc_train.yaml \
    --num_nodes 1 \
    --n_devices 1 \
    --low_vram
```
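
For context, these are the switches I can find that seem to affect peak memory, excerpted from the full config quoted below, with my own annotations (my reading of the code, so corrections welcome):

```yaml
# Excerpt from my config (full file below); the comments are my interpretation.
model:
  params:
    disable_first_stage_autocast: False  # False should mean the VAE already runs under autocast
    en_and_decode_n_samples_a_time: 1    # encode/decode one sample at a time
    num_frames: 2                        # fewer frames -> smaller activations
    network_config:
      params:
        use_checkpoint: True             # gradient checkpointing in the UNet
lightning:
  trainer:
    accumulate_grad_batches: 1           # would grow the effective batch without raising per-step memory
    strategy: deepspeed_stage_2          # ZeRO stage 2 (shards optimizer state across ranks)
```
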
Environment:

- PyTorch version: (e.g. 2.0.1)
- CUDA version: (e.g. 11.8)
- GPU: NVIDIA A100, 48 GB

My full config (`configs/example/nusc_train.yaml`):
```yaml
model:
  base_learning_rate: 5.e-5
  target: vwm.models.diffusion.DiffusionEngine
  params:
    use_ema: True
    input_key: img_seq
    scale_factor: 0.18215
    disable_first_stage_autocast: False
    en_and_decode_n_samples_a_time: 1
    num_frames: &num_frames 2
    slow_spatial_layers: True
    train_peft_adapters: False
    replace_cond_frames: &replace_cond_frames True
    fixed_cond_frames: # only used for logging images
      - [ 0, 1, 2 ]

    denoiser_config:
      target: vwm.modules.diffusionmodules.denoiser.Denoiser
      params:
        num_frames: *num_frames
        scaling_config:
          target: vwm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: vwm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [ 3, 1, 1 ]
        add_lora: False
        action_control: True

    conditioner_config:
      target: vwm.modules.GeneralConditioner
      params:
        emb_models:
          - input_key: cond_frames_without_noise
            is_trainable: False
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
            params:
              n_cond_frames: 1
              n_copies: 1
              open_clip_embedding_config:
                target: vwm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
                params:
                  freeze: True

          - input_key: fps_id
            is_trainable: False
            ucg_rate: 0.0
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - input_key: motion_bucket_id
            is_trainable: False
            ucg_rate: 0.0
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - input_key: cond_frames
            is_trainable: False
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
            params:
              disable_encoder_autocast: False
              n_cond_frames: 1
              n_copies: 1
              is_ae: True
              encoder_config:
                target: vwm.models.autoencoder.AutoencoderKLModeOnly
                params:
                  embed_dim: 4
                  monitor: val/rec_loss
                  ddconfig:
                    attn_type: vanilla-xformers
                    double_z: True
                    z_channels: 4
                    resolution: 256
                    in_channels: 3
                    out_ch: 3
                    ch: 128
                    ch_mult: [ 1, 2, 4, 4 ]
                    num_res_blocks: 2
                    attn_resolutions: [ ]
                    dropout: 0.0
                  loss_config:
                    target: torch.nn.Identity

          - input_key: cond_aug
            is_trainable: False
            ucg_rate: 0.0
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - input_key: command
            is_trainable: False
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: &action_emb_dim 128
              num_features: 1
              add_sequence_dim: True

          - input_key: trajectory
            is_trainable: False
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: *action_emb_dim
              num_features: 8
              add_sequence_dim: True

          - input_key: speed
            is_trainable: False
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: *action_emb_dim
              num_features: 4
              add_sequence_dim: True

          - input_key: angle
            is_trainable: False
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: *action_emb_dim
              num_features: 4
              add_sequence_dim: True

          - input_key: goal
            is_trainable: False
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: *action_emb_dim
              num_features: 2
              add_sequence_dim: True

    first_stage_config:
      target: vwm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: vwm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: vwm.modules.diffusionmodules.model.Encoder
          params:
            attn_type: vanilla
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [ 1, 2, 4, 4 ]
            num_res_blocks: 2
            attn_resolutions: [ ]
            dropout: 0.0
        decoder_config:
          target: vwm.modules.autoencoding.temporal_ae.VideoDecoder
          params:
            attn_type: vanilla
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [ 1, 2, 4, 4 ]
            num_res_blocks: 2
            attn_resolutions: [ ]
            dropout: 0.0
            video_kernel_size: [ 3, 1, 1 ]

    scheduler_config:
      target: vwm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1000 ]
        cycle_lengths: [ 10000000000000 ]
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    loss_fn_config:
      target: vwm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        use_additional_loss: True
        offset_noise_level: 0.02
        additional_loss_weight: 0.1
        num_frames: *num_frames
        replace_cond_frames: *replace_cond_frames
        cond_frames_choices:
          - [ ]
          - [ 0 ]
          - [ 0, 1 ]
          - [ 0, 1, 2 ]
        sigma_sampler_config:
          target: vwm.modules.diffusionmodules.sigma_sampling.EDMSampling
          params:
            p_mean: 1.0
            p_std: 1.6
            num_frames: *num_frames
        loss_weighting_config:
          target: vwm.modules.diffusionmodules.loss_weighting.VWeighting

    sampler_config:
      target: vwm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 15
        discretization_config:
          target: vwm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0
        guider_config:
          target: vwm.modules.diffusionmodules.guiders.LinearPredictionGuider
          params:
            num_frames: *num_frames
            max_scale: 3.0
            min_scale: 1.5

data:
  target: vwm.data.dataset.Sampler
  params:
    batch_size: 1
    num_workers: 1
    subsets:
      - NuScenes
    probs:
      - 1
    samples_per_epoch: 16000
    target_height: 64
    target_width: 64
    num_frames: *num_frames

lightning:
  callbacks:
    image_logger:
      target: train.ImageLogger
      params:
        num_frames: *num_frames
        disabled: False
        enable_autocast: True
        batch_frequency: 100
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          N: *num_frames

    modelcheckpoint:
      params:
        every_n_epochs: 1 # every_n_train_steps: 5000, set the same as image_logger batch_frequency

  trainer:
    devices: 0
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 100
    strategy: deepspeed_stage_2
    gradient_clip_val: 0.3
```
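
Concretely, is the right fix simply to set Lightning's `precision` flag in the trainer block? Below is a sketch of what I have in mind; it assumes the `lightning.trainer` keys are forwarded verbatim to `pytorch_lightning.Trainer` (which I have not verified), and the exact value depends on the Lightning version (`16` on 1.x, `16-mixed` on 2.x):

```yaml
# Hypothetical, untested change: let Lightning manage AMP (autocast + loss scaling).
lightning:
  trainer:
    precision: 16-mixed   # Lightning 2.x syntax; plain 16 on Lightning 1.x
    devices: 0
    benchmark: True
    strategy: deepspeed_stage_2
```

I am also not sure how this interacts with `deepspeed_stage_2`, which has its own fp16 path, so any guidance is appreciated.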

