[feat]: support dataloader resume by skip_first_batches #416

Open · wants to merge 1 commit into base: main
34 changes: 27 additions & 7 deletions opensora/train/train_inpaint.py
@@ -70,8 +70,9 @@
logger = get_logger(__name__)

class ProgressInfo:
def __init__(self, global_step, train_loss=0.0):
def __init__(self, global_step, local_step, train_loss=0.0):
self.global_step = global_step
self.local_step = local_step # used for dataloader resume
self.train_loss = train_loss

#################################################################################
@@ -413,16 +414,17 @@ def load_model_hook(models, input_dir):
logger.info(f" Total optimization steps = {args.max_train_steps}")
logger.info(f" Total trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9} B")
global_step = 0
local_step = 0
first_epoch = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
# Get the most recent checkpoint, dir format: checkpoint-{global_step}-{local_step}
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = [d for d in dirs if d.startswith("checkpoint") and len(d.split("-")) == 3]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None

@@ -436,6 +438,7 @@ def load_model_hook(models, input_dir):
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
local_step = int(path.split("-")[2])

initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
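Note: the resume branch above assumes the new two-part directory name. A minimal standalone sketch (not part of this diff; the helper name `find_latest_checkpoint` is hypothetical) of the selection and parsing it relies on:

```python
# Hypothetical helper illustrating the checkpoint-{global_step}-{local_step}
# naming convention used by the "latest" resume path; not part of the PR itself.
import os

def find_latest_checkpoint(output_dir):
    """Return (path, global_step, local_step) for the newest checkpoint, or None."""
    dirs = [
        d for d in os.listdir(output_dir)
        if d.startswith("checkpoint") and len(d.split("-")) == 3
    ]
    if not dirs:
        return None
    latest = max(dirs, key=lambda d: int(d.split("-")[1]))  # newest by global_step
    _, global_step, local_step = latest.split("-")
    return os.path.join(output_dir, latest), int(global_step), int(local_step)
```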
@@ -453,7 +456,7 @@
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
progress_info = ProgressInfo(global_step, train_loss=0.0)
progress_info = ProgressInfo(global_step, local_step, train_loss=0.0)

def sync_gradients_info(loss):
# Checks if the accelerator has performed an optimization step behind the scenes
@@ -493,7 +496,7 @@ def sync_gradients_info(loss):
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)

save_path = os.path.join(args.output_dir, f"checkpoint-{progress_info.global_step}")
save_path = os.path.join(args.output_dir, f"checkpoint-{progress_info.global_step}-{progress_info.local_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")

@@ -696,12 +699,29 @@ def preprocess_x_for_inpaint(x):
return False

def train_all_epoch(prof_=None):
def train_all_epoch(first_epoch, prof_=None):
# resume last epoch by skipping `local_step` batches
# https://huggingface.co/docs/accelerate/usage_guides/checkpoint
if progress_info.local_step != 0:
# https://github.com/huggingface/accelerate/issues/2823
train_dataloader.set_epoch(first_epoch)

logger.info(f"resume dataloader, global_step: {progress_info.global_step}, skip_first_batches: {progress_info.local_step}")
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, progress_info.local_step)
for step, data_item in enumerate(skipped_dataloader):
if train_one_step(step, data_item, prof_):
break

first_epoch += 1

# continue remaining epoch
for epoch in range(first_epoch, args.num_train_epochs):
progress_info.local_step = 0
progress_info.train_loss = 0.0
if progress_info.global_step >= args.max_train_steps:
return True

for step, data_item in enumerate(train_dataloader):
progress_info.local_step += 1
if train_one_step(step, data_item, prof_):
break
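Note on the resume flow above: a hedged, self-contained sketch of the same pattern, assuming an accelerate-prepared dataloader; the `resume_interrupted_epoch` helper is hypothetical and only illustrates `skip_first_batches` combined with `set_epoch` (see accelerate issue #2823):

```python
# Hypothetical helper mirroring the resume logic in this hunk: re-seed the
# shuffle for the interrupted epoch, skip the batches already trained, finish
# that epoch, then hand back the next epoch index to continue from.
def resume_interrupted_epoch(accelerator, train_dataloader, first_epoch,
                             local_step, train_one_step):
    if hasattr(train_dataloader, "set_epoch"):
        # keep shuffling consistent with the epoch being resumed
        # (https://github.com/huggingface/accelerate/issues/2823)
        train_dataloader.set_epoch(first_epoch)
    skipped = accelerator.skip_first_batches(train_dataloader, local_step)
    for step, batch in enumerate(skipped):
        if train_one_step(step, batch):
            break
    return first_epoch + 1
```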

@@ -726,9 +746,9 @@ def train_all_epoch(prof_=None):
skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(f"{profile_output_path}/")
) as prof:
train_all_epoch(prof)
train_all_epoch(first_epoch, prof)
else:
train_all_epoch()
train_all_epoch(first_epoch)
accelerator.wait_for_everyone()
accelerator.end_training()
if npu_config is not None and get_sequence_parallel_state():
36 changes: 28 additions & 8 deletions opensora/train/train_t2v_diffusers.py
@@ -150,8 +150,9 @@ def log_validation(args, model, vae, text_encoder, tokenizer, accelerator, weigh


class ProgressInfo:
def __init__(self, global_step, train_loss=0.0):
def __init__(self, global_step, local_step, train_loss=0.0):
self.global_step = global_step
self.local_step = local_step # used for dataloader resume
self.train_loss = train_loss


@@ -501,16 +502,17 @@ def load_model_hook(models, input_dir):
logger.info(f" Total optimization steps = {args.max_train_steps}")
logger.info(f" Total training parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9} B")
global_step = 0
local_step = 0
first_epoch = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
# Get the most recent checkpoint, dir format: checkpoint-{global_step}-{local_step}
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = [d for d in dirs if d.startswith("checkpoint") and len(d.split("-")) == 3]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None

@@ -524,6 +526,7 @@ def load_model_hook(models, input_dir):
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
local_step = int(path.split("-")[2])

initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
@@ -541,7 +544,7 @@
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
progress_info = ProgressInfo(global_step, train_loss=0.0)
progress_info = ProgressInfo(global_step, local_step, train_loss=0.0)

def sync_gradients_info(loss):
# Checks if the accelerator has performed an optimization step behind the scenes
@@ -581,7 +584,7 @@ def sync_gradients_info(loss):
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)

save_path = os.path.join(args.output_dir, f"checkpoint-{progress_info.global_step}")
save_path = os.path.join(args.output_dir, f"checkpoint-{progress_info.global_step}-{progress_info.local_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")

@@ -784,13 +787,30 @@ def train_one_step(step_, data_item_, prof_=None):

return False

def train_all_epoch(prof_=None):
def train_all_epoch(first_epoch, prof_=None):
# resume last epoch by skipping `local_step` batches
# https://huggingface.co/docs/accelerate/usage_guides/checkpoint
if progress_info.local_step != 0:
# https://github.com/huggingface/accelerate/issues/2823
train_dataloader.set_epoch(first_epoch)

logger.info(f"resume dataloader, global_step: {progress_info.global_step}, skip_first_batches: {progress_info.local_step}")
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, progress_info.local_step)
for step, data_item in enumerate(skipped_dataloader):
if train_one_step(step, data_item, prof_):
break

first_epoch += 1

# continue remaining epoch
for epoch in range(first_epoch, args.num_train_epochs):
progress_info.local_step = 0
progress_info.train_loss = 0.0
if progress_info.global_step >= args.max_train_steps:
return True

for step, data_item in enumerate(train_dataloader):
progress_info.local_step += 1
if train_one_step(step, data_item, prof_):
break
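The change here mirrors train_inpaint.py. For clarity, a toy example of what `accelerator.skip_first_batches` returns (not from the PR; assumes a CPU run with a default `Accelerator()` and an unshuffled dataloader):

```python
# Toy demonstration: skip_first_batches yields the same batches as the original
# dataloader minus the first `num_batches`, so iteration resumes mid-epoch.
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
dataloader = accelerator.prepare(
    DataLoader(TensorDataset(torch.arange(8)), batch_size=2, shuffle=False)
)
skipped = accelerator.skip_first_batches(dataloader, num_batches=2)
print([batch[0].tolist() for batch in skipped])  # [[4, 5], [6, 7]]
```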

@@ -815,9 +835,9 @@ def train_all_epoch(prof_=None):
skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(f"{profile_output_path}/")
) as prof:
train_all_epoch(prof)
train_all_epoch(first_epoch, prof)
else:
train_all_epoch()
train_all_epoch(first_epoch)
accelerator.wait_for_everyone()
accelerator.end_training()
if get_sequence_parallel_state():