Commit 87e742d

Merge branch 'feature/zerobubble' into feature/zerobubble

2 parents: 0aab903 + af6aa9e

5 files changed: +13 -14 lines

colossalai/booster/plugin/hybrid_parallel_plugin.py
Lines changed: 10 additions & 10 deletions

@@ -313,12 +313,8 @@ def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         """

         # Call the superclass backward method to compute gradients.
-<<<<<<< HEAD
         with self.model._hook_context():
-            super().backward(loss, *args, **kwargs)
-=======
-        super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
->>>>>>> [plugin] hybrid support zero bubble pipeline (#6060)
+            super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -541,12 +537,8 @@ def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
             None
         """
         # Call the superclass backward method to compute gradients.
-<<<<<<< HEAD
         with self.model._hook_context():
-            super().backward(loss, *args, **kwargs)
-=======
-        super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
->>>>>>> [plugin] hybrid support zero bubble pipeline (#6060)
+            super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
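
Both hunks above resolve the same merge conflict the same way: the merge keeps the `with self.model._hook_context():` wrapper from HEAD but takes the explicit `inputs=`/`retain_graph=` forwarding from the zero-bubble branch, replacing the old `*args, **kwargs` pass-through. Below is a minimal, self-contained sketch of that pattern; the class names, the `nullcontext` stand-in for `_hook_context`, and the base-class behavior are assumptions for illustration, not ColossalAI's actual implementation.

import contextlib

import torch
from torch import Tensor, nn

class BaseOptimWrapper:
    # Hypothetical stand-in for the superclass whose backward() is forwarded to.
    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        loss.backward(inputs=inputs, retain_graph=retain_graph)

class HybridOptimWrapperSketch(BaseOptimWrapper):
    # Sketch of the resolved method: hook context from HEAD, explicit
    # inputs/retain_graph forwarding from the zero-bubble branch.
    def __init__(self, model: nn.Module):
        self.model = model

    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        # Compute gradients while the model's backward hooks are active.
        with self.model._hook_context():
            super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

model = nn.Linear(4, 1)
model._hook_context = contextlib.nullcontext  # stand-in for the real hook context
wrapper = HybridOptimWrapperSketch(model)
loss = model(torch.randn(2, 4)).sum()
wrapper.backward(loss)
print(model.weight.grad is not None)  # True: gradients were computed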
@@ -1174,6 +1166,14 @@ def __init__(
                 num_microbatch=num_microbatches,
                 microbatch_size=microbatch_size,
             )
+        elif pp_style == "zbv":
+            self.scheduler = ZeroBubbleVPipeScheduler(
+                stage_manager=self.stage_manager,
+                schedule=scheduler_nodes,
+                num_model_chunks=num_model_chunks,
+                num_microbatch=num_microbatches,
+                microbatch_size=microbatch_size,
+            )
         else:
             raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
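
The new elif arm registers "zbv" as an additional recognized `pp_style`, constructing a `ZeroBubbleVPipeScheduler` from the precomputed `scheduler_nodes` plan plus the same microbatch settings the existing branches use. A hedged sketch of how a caller might select it: only `pp_style`, `num_model_chunks`, `num_microbatches`, `microbatch_size`, and `scheduler_nodes` come from the hunk above; `tp_size`/`pp_size` and how the node plan is produced are assumptions about the constructor.

from colossalai.booster.plugin import HybridParallelPlugin

# Hypothetical invocation; the full constructor signature is not shown in this diff.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=4,
    pp_style="zbv",       # new: dispatches to ZeroBubbleVPipeScheduler
    num_model_chunks=2,   # the V schedule places two model chunks per stage
    num_microbatches=8,
    scheduler_nodes=...,  # precomputed ScheduledNode plan (see v_schedule.py)
)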

colossalai/pipeline/schedule/v_schedule.py
Lines changed: 1 addition & 1 deletion

@@ -491,4 +491,4 @@ def even_breaker(x: ScheduledNode):
         #         print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=", ")
         #     print()

-        return local_order_with_rollback
\ No newline at end of file
+        return local_order_with_rollback

colossalai/pipeline/schedule/zero_bubble_pp.py
Lines changed: 1 addition & 1 deletion

@@ -902,4 +902,4 @@ def forward_backward_step(

         self.assert_buffer_empty()

-        return result
\ No newline at end of file
+        return result

colossalai/zero/low_level/low_level_optim.py
Lines changed: 0 additions & 1 deletion

@@ -433,7 +433,6 @@ def backward(self, loss, inputs=None, retain_graph=False):

         ctx = nullcontext() if self._backward_context is None else self._backward_context()
         with ctx:
-            loss.backward(retain_graph=retain_graph)
             loss.backward(inputs=inputs, retain_graph=retain_graph)

         if not self.require_grad_sync:
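
The deleted line looks like a merge leftover: running both calls would either raise (the autograd graph is freed after the first `backward()` when `retain_graph=False`) or double-accumulate gradients, and the first call also dropped `inputs`. The surviving call forwards `inputs=`, which restricts accumulation to the listed leaf tensors; this is stock `torch.Tensor.backward` behavior, and it is what lets a zero-bubble schedule separate input-gradient from weight-gradient work. A self-contained illustration:

import torch

x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
loss = (x * w) ** 2

# With inputs= given, only the listed leaves accumulate gradients.
loss.backward(inputs=[w])
print(x.grad)  # None: x was excluded from accumulation
print(w.grad)  # tensor(24.): d/dw (x*w)^2 = 2 * x**2 * w = 24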

tests/test_pipeline/test_schedule/test_zerobubble_pp.py
Lines changed: 1 addition & 1 deletion

@@ -863,4 +863,4 @@ def test_pp():


 if __name__ == "__main__":
-    test_pp()
\ No newline at end of file
+    test_pp()
