Commit 0aab903

flybird11111, pre-commit-ci[bot], and duanjunwen committed
[plugin] hybrid support zero bubble pipeline (#6060)
* hybrid support zbv
* fix fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* Update zero_bubble_pp.py
* fix
* fix-ci
* fix; [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci; fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* fix
* fix
* fix
* [zerobubble] Support ZeroBubble Pipeline (#6034)
* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;
* [feat] add dw test;
* [fix] fix weight not close;
* [update] update text;
* [feat] add test run_fwd_bwd automatic scheduling;
* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;
* [feat] add test for p & p grad;
* [feat] add comments for ZBV func;
* [fix] rm useless assign and comments;
* [fix] fix ci test; add pytest;
* [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p & p.grad assert close test & all pass;
* [feat] add apply v_schedule graph; p & p.grad assert err exist;
* [fix] update
* [feat] fix ci; add assert;
* [feat] fix poc format
* [feat] fix func name & ci; add comments;
* [fix] fix poc test; add comments in poc;
* [feat] add optim backward_b_by_grad
* [feat] fix optimizer bwd b & w; support return accum loss & output
* [feat] add fwd_bwd_step, run_fwd_only;
* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while" --> "for"; add communication dict;
* [fix] fix communication_map;
* [feat] update test; rm comments;
* [fix] rm zbv in hybridplugin
* [fix] fix optim bwd;
* [fix] fix optim bwd;
* [fix] rm output.data after send fwd;
* [fix] fix bwd step if condition; remove useless comments and format info;
* [fix] fix detach output & release output;
* [fix] rm requir_grad for output;
* [fix] fix requir grad position and detach position and input & output local buffer append position;
* [feat] add memory assertation;
* [fix] fix mem check;
* [fix] mem assertation
* [fix] fix mem assertation
* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;
* [fix] fix model zoo import;
* [fix] fix redundant detach & clone; add buffer assertation in the end;
* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;
* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;
* [fix] add testcase with microbatch 4;
* hybrid support zbv
* fix fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update zero_bubble_pp.py
* fix
* fix-ci
* fix; [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci; fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* fix
* fix
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* fix
* fix
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <[email protected]>
1 parent f345f5d commit 0aab903

15 files changed: +146 additions, -49 deletions

.github/workflows/build_on_pr.yml
Lines changed: 1 addition & 1 deletion

@@ -140,7 +140,7 @@ jobs:

      - name: Install Colossal-AI
        run: |
-         BUILD_EXT=1 pip install -v -e .
+         BUILD_EXT=1 pip install -v .
          pip install --no-cache-dir -r requirements/requirements-test.txt

      - name: Store Colossal-AI Cache

.github/workflows/build_on_schedule.yml
Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ jobs:
        if: steps.check-avai.outputs.avai == 'true'
        run: |
          [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
-         BUILD_EXT=1 pip install -v -e .
+         BUILD_EXT=1 pip install -v .
          cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
          pip install --no-cache-dir -r requirements/requirements-test.txt

colossalai/amp/naive_amp/mixed_precision_mixin/base.py
Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ def zero_grad(self):
    dtype: torch.dtype

    @abstractmethod
-   def pre_backward(self, loss: Tensor) -> Tensor:
+   def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
        """Called before backward.

        Args:

colossalai/amp/naive_amp/mixed_precision_optimizer.py
Lines changed: 9 additions & 4 deletions

@@ -85,13 +85,18 @@ def __init__(
                master_params.append(master_p)
            group["params"] = master_params

-   def backward(self, loss: Tensor, *args, **kwargs):
+   def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        loss = self.mixed_precision.pre_backward(loss)
-       loss.backward(*args, **kwargs)
+       loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

-   def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+   def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
-       tensor.backward(grad)
+       torch.autograd.backward(
+           tensors=tensor,
+           grad_tensors=grad,
+           inputs=inputs,
+           retain_graph=retain_graph,
+       )

    def zero_grad(self, *args, **kwargs):
        for p in self.working_to_master_map.keys():

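The switch from tensor.backward(grad) to torch.autograd.backward(..., inputs=..., retain_graph=...) is what lets a zero-bubble schedule split one backward pass into an activation-gradient step (B) and a later weight-gradient step (W). A minimal, standalone sketch of that split with toy tensors (not ColossalAI code):

import torch

# Toy layer: y = x @ w with a scalar loss on top.
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 2, requires_grad=True)
loss = (x @ w).sum()

# Step B: gradient w.r.t. the layer input only (what the previous stage needs).
# retain_graph=True keeps the graph alive for the later weight-gradient step.
torch.autograd.backward(loss, inputs=[x], retain_graph=True)
assert x.grad is not None and w.grad is None

# Step W: gradient w.r.t. the weight, computed later in the schedule.
torch.autograd.backward(loss, inputs=[w])
assert w.grad is not None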
colossalai/booster/mixed_precision/fp16_torch.py
Lines changed: 2 additions & 2 deletions

@@ -46,9 +46,9 @@ def __init__(
            growth_interval=growth_interval,
        )

-   def backward(self, loss: Tensor, *args, **kwargs) -> None:
+   def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
        scaled_loss = self.scale_loss(loss)
-       scaled_loss.backward(*args, **kwargs)
+       scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

    def step(self, *args, **kwargs) -> Optional[float]:
        out = self.scaler.step(self.optim, *args, **kwargs)

colossalai/booster/plugin/hybrid_parallel_plugin.py
Lines changed: 48 additions & 19 deletions

@@ -28,7 +28,7 @@
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
-from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.quantization.fp8_hook import FP8Hook

@@ -295,7 +295,7 @@ def __init__(
        self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
        super().__init__(optim)

-   def backward(self, loss: Tensor, *args, **kwargs):
+   def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        r"""
        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -313,8 +313,12 @@ def backward(self, loss: Tensor, *args, **kwargs):
        """

        # Call the superclass backward method to compute gradients.
        with self.model._hook_context():
-           super().backward(loss, *args, **kwargs)
+           super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.

@@ -323,7 +327,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
            # If gradient synchronization is is not required, return.
            return

-   def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+   def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -340,7 +344,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
        """

        # Call the superclass backward method to compute gradients.
-       super().backward_by_grad(tensor, grad)
+       super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.

@@ -520,7 +524,7 @@ def __init__(
            max_norm=max_norm,
        )

-   def backward(self, loss: Tensor, *args, **kwargs):
+   def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        r"""
        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -537,8 +541,12 @@ def backward(self, loss: Tensor, *args, **kwargs):
            None
        """
        # Call the superclass backward method to compute gradients.
        with self.model._hook_context():
-           super().backward(loss, *args, **kwargs)
+           super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.

@@ -547,7 +555,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
            # If gradient synchronization is is not required, return.
            return

-   def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+   def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -563,7 +571,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
            None
        """
        # Call the superclass backward method to compute gradients.
-       super().backward_by_grad(tensor, grad)
+       super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.

@@ -780,7 +788,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
        else:
            return

-   def backward(self, loss, retain_graph=False):
+   def backward(self, loss, inputs=None, retain_graph=False):
        """
        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -796,7 +804,7 @@ def backward(self, loss, retain_graph=False):
            None
        """
        # Call the superclass backward method to compute gradients.
-       super().backward(loss, retain_graph)
+       super().backward(loss, inputs=inputs, retain_graph=retain_graph)

        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
            # If gradient synchronization is required, sync sequence parallelism gradients.

@@ -805,7 +813,7 @@ def backward(self, loss, retain_graph=False):
            # If gradient synchronization is is not required, return.
            return

-   def backward_by_grad(self, tensor, grad):
+   def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -821,7 +829,7 @@ def backward_by_grad(self, tensor, grad):
            None
        """
        # Call the superclass backward_by_grad method to compute gradients.
-       super().backward_by_grad(tensor, grad)
+       super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
            # If gradient synchronization is required, sync sequence parallelism gradients.

@@ -1026,6 +1034,7 @@ def __init__(
        custom_policy: Policy = None,
        pp_style: str = "1f1b",
        num_model_chunks: int = 1,
+       scheduler_nodes: List = None,
        num_layers_per_stage: Optional[List[int]] = None,
        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
        enable_metadata_cache: bool = True,

@@ -1044,6 +1053,9 @@ def __init__(
            dist.get_world_size() % (tp_size * pp_size) == 0
        ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

+       assert (
+           not pp_style == "zbv" or scheduler_nodes is not None
+       ), f"scheduler_nodes must not be None when using zero bubble pipeline."
        if enable_sequence_parallelism:
            self.sequence_parallelism_mode = (
                sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"

@@ -1105,29 +1117,39 @@ def __init__(
        self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

        self.stage_manager = None
-       self.schedule = None
+       self.scheduler = None
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
-           assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-           assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+           assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+           assert (
+               pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
+           ), "num_model_chunks must be 1 when using 1f1b"
+           assert (
+               pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
+           ), "num_model_chunks must be 2 when using zero bubble pipeline"
            assert (
                num_microbatches is not None or microbatch_size is not None
            ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
            assert (
                self.zero_stage <= 1
            ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
+           if pp_style == "zbv":
+               self.logger.warning(
+                   """the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
+               )
            self.stage_manager = PipelineStageManager(
                self.pg_mesh,
                pipeline_axis=self.pp_axis,
-               enable_interleave=(pp_style == "interleaved"),
+               enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
+               use_zbv=(pp_style == "zbv"),
                num_model_chunks=num_model_chunks,
                num_layers_per_stage=num_layers_per_stage,
            )

            if pp_style == "interleaved":
                assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
-               self.schedule = InterleavedSchedule(
+               self.scheduler = InterleavedSchedule(
                    stage_manager=self.stage_manager,
                    num_model_chunks=num_model_chunks,
                    num_microbatch=num_microbatches,

@@ -1137,13 +1159,21 @@ def __init__(
                    fp8_communication=fp8_communication,
                )
            elif pp_style == "1f1b":
-               self.schedule = OneForwardOneBackwardSchedule(
+               self.scheduler = OneForwardOneBackwardSchedule(
                    stage_manager=self.stage_manager,
                    num_microbatches=num_microbatches,
                    microbatch_size=microbatch_size,
                    enable_metadata_cache=enable_metadata_cache,
                    fp8_communication=fp8_communication,
                )
+           elif pp_style == "zbv":
+               self.scheduler = ZeroBubbleVPipeScheduler(
+                   stage_manager=self.stage_manager,
+                   schedule=scheduler_nodes,
+                   num_model_chunks=num_model_chunks,
+                   num_microbatch=num_microbatches,
+                   microbatch_size=microbatch_size,
+               )
            else:
                raise NotImplementedError()
            if sequence_parallelism_mode == "ring_attn":

@@ -1257,7 +1287,6 @@ def configure(

        # Replace with distributed implementation if exists
        optimizer = cast_to_distributed(optimizer)
-
        if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
            self.logger.warning(
                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",

@@ -1374,7 +1403,7 @@ def execute_pipeline(
        ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

        with ctx, model._hook_context():
-           outputs = self.schedule.forward_backward_step(
+           outputs = self.scheduler.forward_backward_step(
                model, data_iter, criterion, optimizer, return_loss, return_outputs
            )

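Taken together, the new pp_style="zbv" path needs exactly two model chunks per rank and a precomputed list of scheduler nodes. A rough usage sketch follows; the PipelineGraph helper, its argument names, and the cost/memory numbers are assumptions based on the zero-bubble scheduler module (colossalai/pipeline/schedule/v_schedule.py), not something this diff shows, so treat it as illustrative only.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph  # assumed location

colossalai.launch_from_torch()

pp_size, num_microbatches = 4, 8

# Build the V-schedule node list that pp_style="zbv" now requires.
# Relative F/B/W/communication costs and per-phase memory deltas are placeholders.
scheduler_nodes = PipelineGraph(
    n_stage=pp_size,
    n_micro=num_microbatches,
    f_cost=1, b_cost=1, w_cost=1, c_cost=1,
    f_mem=2, b_mem=-1, w_mem=-1,
).get_v_schedule()

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=pp_size,
    pp_style="zbv",                   # new style added by this commit
    num_model_chunks=2,               # zbv asserts exactly two chunks
    num_microbatches=num_microbatches,
    scheduler_nodes=scheduler_nodes,  # must not be None for zbv
)
booster = Booster(plugin=plugin)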
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
Lines changed: 3 additions & 3 deletions

@@ -287,7 +287,7 @@ def __init__(
        self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)

        self.stage_manager = None
-       self.schedule = None
+       self.scheduler = None
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:

@@ -311,7 +311,7 @@ def __init__(

            if pp_style == "interleaved":
                assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
-               self.schedule = InterleavedSchedule(
+               self.scheduler = InterleavedSchedule(
                    stage_manager=self.stage_manager,
                    num_model_chunks=num_model_chunks,
                    num_microbatch=num_microbatches,

@@ -320,7 +320,7 @@ def __init__(
                    overlap_p2p=overlap_p2p,
                )
            elif pp_style == "1f1b":
-               self.schedule = OneForwardOneBackwardSchedule(
+               self.scheduler = OneForwardOneBackwardSchedule(
                    stage_manager=self.stage_manager,
                    num_microbatches=num_microbatches,
                    microbatch_size=microbatch_size,

colossalai/interface/optimizer.py
Lines changed: 2 additions & 2 deletions

@@ -49,11 +49,11 @@ def zero_grad(self, *args, **kwargs):
        """
        self.optim.zero_grad(*args, **kwargs)

-   def backward(self, loss: Tensor, *args, **kwargs):
+   def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        """
        Performs a backward pass on the loss.
        """
-       loss.backward(*args, **kwargs)
+       loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        """

colossalai/pipeline/stage_manager.py
Lines changed: 5 additions & 1 deletion

@@ -136,7 +136,11 @@ def is_last_stage(self, ignore_chunk: bool = False) -> bool:
        if not self.is_interleave or ignore_chunk:
            return self.stage == self.num_stages - 1
        else:
-           return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
+           # use zero bubble pipeline
+           if self.use_zbv:
+               return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1
+           else:
+               return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1

    @property
    def num_stages(self) -> int:

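The new use_zbv branch moves the logical last stage from the bottom rank to rank 0 because the zero-bubble V schedule folds the two model chunks into a V: chunk 0 walks down the ranks and chunk 1 walks back up, so the final layers end up on rank 0's second chunk. A small illustrative helper (hypothetical, not part of this diff) makes that layout concrete:

from typing import List

def v_schedule_layout(pp_size: int) -> List[List[int]]:
    """Per pipeline rank, the virtual stage indices it owns under a V layout."""
    layout = []
    for rank in range(pp_size):
        chunk0 = rank                    # chunk 0: forward leg, walks down the ranks
        chunk1 = 2 * pp_size - 1 - rank  # chunk 1: return leg, walks back up
        layout.append([chunk0, chunk1])
    return layout

# With 4 pipeline ranks: [[0, 7], [1, 6], [2, 5], [3, 4]]
# The last virtual stage (7) lives on rank 0's second chunk, which is why
# is_last_stage() now returns True for stage == 0 and the last model chunk.
print(v_schedule_layout(4))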
colossalai/shardformer/policies/llama.py
Lines changed: 9 additions & 3 deletions

@@ -261,7 +261,9 @@ def get_held_layers(self) -> List[Module]:
                held_layers.append(module.embed_tokens)
            for start_idx, end_idx in stage_indices:
                held_layers.extend(module.layers[start_idx:end_idx])
-           if stage_manager.is_last_stage(ignore_chunk=True):
+           if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+               held_layers.append(module.norm)
+           elif stage_manager.is_last_stage(ignore_chunk=True):
                held_layers.append(module.norm)

        else:

@@ -353,7 +355,9 @@ def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
-       if stage_manager.is_last_stage(ignore_chunk=True):
+       if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+           held_layers.append(self.model.lm_head)
+       elif stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.lm_head)
        return held_layers

@@ -411,7 +415,9 @@ def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
-       if stage_manager.is_last_stage(ignore_chunk=True):
+       if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+           held_layers.append(self.model.score)
+       elif stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.score)
        return held_layers

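Under the same V layout, the policy has to hand the output-side modules (final norm, lm_head, or score head) to the rank that is first in pipeline order but holds the last virtual chunk. A toy sketch of how decoder layers and the head could end up distributed for zbv (a hypothetical even-split helper, not the policy's actual partitioning code):

from typing import Dict, List

def zbv_held_layers(num_layers: int, pp_size: int) -> Dict[int, List[str]]:
    """Toy even split of a decoder stack over a V layout (two chunks per rank)."""
    per_chunk = num_layers // (2 * pp_size)
    held: Dict[int, List[str]] = {rank: [] for rank in range(pp_size)}
    for rank in range(pp_size):
        # chunk 0 (descending leg) and chunk 1 (ascending leg) of the V
        for chunk in (rank, 2 * pp_size - 1 - rank):
            start = chunk * per_chunk
            held[rank] += [f"layers.{i}" for i in range(start, start + per_chunk)]
    held[0].insert(0, "embed_tokens")  # first virtual stage sits on rank 0, chunk 0
    held[0] += ["norm", "lm_head"]     # last virtual stage also sits on rank 0, chunk 1
    return held

# 8 layers, 2 ranks -> rank 0 holds embed_tokens, layers 0-1, layers 6-7, norm, lm_head;
#                      rank 1 holds layers 2-5.
print(zbv_held_layers(8, 2))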