
Commit fbd3159

Support force-freeing memory for the policy model when it is not colocated.

1 parent 6562ee3 · commit fbd3159

File tree: 7 files changed, +11 −5 lines

chatlearn/models/torch_module.py (2 additions, 2 deletions)

@@ -142,7 +142,7 @@ def _get_if_not_none(self, to_set, default):
         return default
 
     def onload(self, to_onload_weights=None, to_build_grad_buffers=None, to_onload_main_weights=None, to_onload_optimizer_states=None):
-        if not self.is_colocate:
+        if not (self.is_colocate or self.module_args.force_free_memory):
             return
         to_onload_weights = self._get_if_not_none(to_onload_weights, self.module_args.offload_weights)
         to_build_grad_buffers = self._get_if_not_none(to_build_grad_buffers, self.module_args.free_grad_buffers)
@@ -175,7 +175,7 @@ def offload(self, to_offload_weights=None, to_free_grad_buffers=None, to_offload
         # The first time of calling `offload_weights` and `offload_main_weights` has a higher peak memory.
         # So `free_grad_buffers` is called first to free memory, and `offload_weights` is called afterward
         # to make more space for `offload_main_weights`.
-        if not self.is_colocate:
+        if not (self.is_colocate or self.module_args.force_free_memory):
             return
         to_offload_weights = self._get_if_not_none(to_offload_weights, self.module_args.offload_weights)
         to_offload_main_weights = self._get_if_not_none(to_offload_main_weights, self.module_args.offload_weights)
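
Both guards change the same way: previously a module that was not colocated returned immediately from `onload`/`offload`, so its GPU memory could never be reclaimed. With `force_free_memory` set, the early return is skipped even without colocation. A minimal runnable sketch of the guard pattern (the `ModuleArgs` and `ModuleSketch` classes are illustrative stand-ins, not chatlearn's real types):

from dataclasses import dataclass

@dataclass
class ModuleArgs:
    # Illustrative stand-in for the relevant ModelConfig fields.
    offload_weights: bool = False
    free_grad_buffers: bool = False
    force_free_memory: bool = False

class ModuleSketch:
    def __init__(self, is_colocate: bool, module_args: ModuleArgs):
        self.is_colocate = is_colocate
        self.module_args = module_args

    def offload(self):
        # Old guard: `if not self.is_colocate: return` -- non-colocated
        # modules never freed memory. The new flag bypasses that check.
        if not (self.is_colocate or self.module_args.force_free_memory):
            return
        print("free grad buffers, then offload weights, then main weights")

# A non-colocated policy model now offloads when the flag is set.
ModuleSketch(is_colocate=False, module_args=ModuleArgs(force_free_memory=True)).offload()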

chatlearn/runtime/executor.py (3 additions, 3 deletions)

@@ -298,9 +298,9 @@ def compute_loop_one_model(self, model_node, num_batch=None):
         results = []
         self.timers(f"{model.name}").start()
         for step in range(num_batch):
-            to_empty_cache = step >= last_step_start and model.is_colocate
-            to_onload = step < replica_num and model.is_colocate and model.enable_offload
-            to_offload = step >= last_step_start and model.is_colocate and model.enable_offload
+            to_empty_cache = step >= last_step_start and (model.is_colocate or model.module_args.force_free_memory)
+            to_onload = step < replica_num and ((model.is_colocate and model.enable_offload) or model.module_args.force_free_memory)
+            to_offload = step >= last_step_start and ((model.is_colocate and model.enable_offload) or model.module_args.force_free_memory)
             replica = self._next_model(model)
             _, data = self.generate_step_one_model(model_node, replica, in_queue, model_node.out_queues, step, func_name, to_empty_cache,
                                                    is_eval=is_eval, to_onload=to_onload, to_offload=to_offload)
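
In the executor loop these flags stage memory traffic at the loop boundaries: onload during the first `replica_num` steps, empty-cache and offload once the final steps begin. A self-contained sketch of the flag computation (the `Model` dataclass and the `last_step_start = num_batch - replica_num` definition are assumptions for illustration; only the three boolean expressions come from the hunk):

from dataclasses import dataclass

@dataclass
class Model:
    is_colocate: bool
    enable_offload: bool
    force_free_memory: bool  # stands in for model.module_args.force_free_memory

def schedule(model: Model, num_batch: int, replica_num: int):
    # Assumed definition: the last `replica_num` steps are the closing steps.
    last_step_start = num_batch - replica_num
    for step in range(num_batch):
        to_empty_cache = step >= last_step_start and (model.is_colocate or model.force_free_memory)
        to_onload = step < replica_num and ((model.is_colocate and model.enable_offload) or model.force_free_memory)
        to_offload = step >= last_step_start and ((model.is_colocate and model.enable_offload) or model.force_free_memory)
        yield step, to_onload, to_empty_cache, to_offload

# A non-colocated model with the flag set still onloads first and offloads last.
for row in schedule(Model(False, False, True), num_batch=4, replica_num=1):
    print(row)  # (0, True, False, False) ... (3, False, True, True)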

chatlearn/utils/arguments.py (2 additions, 0 deletions)

@@ -234,6 +234,8 @@ class ModelConfig(BaseConfig):
     free_grad_buffers = False
     #: overall switch for offload optimizer states/weights and free grad buffers
     free_memory = False
+    #: force freeing memory even when the model is not colocated
+    force_free_memory = False
 
     def __init__(self):
         super().__init__()
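
The new field follows the class-attribute-default pattern of the surrounding switches, so it can be overridden per model from YAML or code. A hedged sketch of that pattern (this `BaseConfig` is an illustrative stand-in, not chatlearn's actual base class):

class BaseConfig:
    # Illustrative stand-in: copy class-level defaults onto the instance
    # so per-model configs can override them independently.
    def __init__(self):
        for name, value in vars(type(self)).items():
            if not name.startswith("_") and not callable(value):
                setattr(self, name, value)

class ModelConfig(BaseConfig):
    #: overall switch for offload optimizer states/weights and free grad buffers
    free_memory = False
    #: force freeing memory even when the model is not colocated
    force_free_memory = False

cfg = ModelConfig()
cfg.force_free_memory = True  # e.g. set from the YAML interpolation below
print(cfg.free_memory, cfg.force_free_memory)  # False True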

examples/megatron/configs/llama2/grpo_math_vllm.yaml (1 addition, 0 deletions)

@@ -16,6 +16,7 @@ models:
     ranking: ${batch_generation_ranking:False}
     min_prompt_length: ${batch_generation_min_prompt_length:0}
     free_memory: ${free_memory_policy:False}
+    force_free_memory: ${force_free_memory_policy:False}
 
   reference:
     model_config_file: reference.yaml

examples/megatron/configs/llama2/online_dpo_vllm.yaml (1 addition, 0 deletions)

@@ -16,6 +16,7 @@ models:
     ranking: ${batch_generation_ranking:False}
     min_prompt_length: ${batch_generation_min_prompt_length:0}
     free_memory: ${free_memory_policy:False}
+    force_free_memory: ${force_free_memory_policy:False}
 
   reference:
     model_config_file: reference.yaml

examples/megatron/configs/llama2/rlhf.yaml (1 addition, 0 deletions)

@@ -16,6 +16,7 @@ models:
     ranking: ${batch_generation_ranking:False}
     min_prompt_length: ${batch_generation_min_prompt_length:0}
     free_memory: ${free_memory_policy:False}
+    force_free_memory: ${force_free_memory_policy:False}
 
   reference:
     model_config_file: reference.yaml

examples/megatron/configs/llama2/vllm_rlhf.yaml (1 addition, 0 deletions)

@@ -16,6 +16,7 @@ models:
     ranking: ${batch_generation_ranking:False}
     min_prompt_length: ${batch_generation_min_prompt_length:0}
     free_memory: ${free_memory_policy:False}
+    force_free_memory: ${force_free_memory_policy:False}
 
   reference:
     model_config_file: reference.yaml
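
All four configs bind the new key to `${force_free_memory_policy:False}`, the same `${env_var:default}` interpolation used by the neighboring keys: read the named environment variable, fall back to the default after the colon when it is unset. A minimal sketch of that resolution rule (the regex resolver is illustrative, not chatlearn's actual YAML parser):

import os
import re

_PATTERN = re.compile(r"\$\{(\w+)(?::([^}]*))?\}")

def resolve(value: str) -> str:
    # Replace ${name:default} with the environment variable `name`,
    # falling back to the default after the colon when it is unset.
    def _sub(m: re.Match) -> str:
        name, default = m.group(1), m.group(2) or ""
        return os.environ.get(name, default)
    return _PATTERN.sub(_sub, value)

os.environ["force_free_memory_policy"] = "True"
print(resolve("${force_free_memory_policy:False}"))  # True
print(resolve("${free_memory_policy:False}"))        # False (unset, default)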
