Support force-freeing memory for the policy model when it is not colocated.
adoda committed Jan 26, 2025
1 parent 6562ee3 commit fbd3159
Showing 7 changed files with 11 additions and 5 deletions.
4 changes: 2 additions & 2 deletions chatlearn/models/torch_module.py

@@ -142,7 +142,7 @@ def _get_if_not_none(self, to_set, default):
         return default

     def onload(self, to_onload_weights=None, to_build_grad_buffers=None, to_onload_main_weights=None, to_onload_optimizer_states=None):
-        if not self.is_colocate:
+        if not (self.is_colocate or self.module_args.force_free_memory):
             return
         to_onload_weights = self._get_if_not_none(to_onload_weights, self.module_args.offload_weights)
         to_build_grad_buffers = self._get_if_not_none(to_build_grad_buffers, self.module_args.free_grad_buffers)
@@ -175,7 +175,7 @@ def offload(self, to_offload_weights=None, to_free_grad_buffers=None, to_offload
         # The first time of calling `offload_weights` and `offload_main_weights` has a higher peak memory.
         # So `free_grad_buffers` is called first to free memory, and `offload_weights` is called afterward
         # to make more space for `offload_main_weights`.
-        if not self.is_colocate:
+        if not (self.is_colocate or self.module_args.force_free_memory):
             return
         to_offload_weights = self._get_if_not_none(to_offload_weights, self.module_args.offload_weights)
         to_offload_main_weights = self._get_if_not_none(to_offload_main_weights, self.module_args.offload_weights)
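This guard change is the heart of the commit: previously `onload` and `offload` were no-ops for any non-colocated model. A minimal standalone sketch of the new semantics follows; the toy class and flag wiring are illustrative, not ChatLearn's real TorchModule:

    # Minimal sketch of the new guard semantics; this toy class only
    # mirrors the condition from the diff above.
    class ToyModule:
        def __init__(self, is_colocate, force_free_memory):
            self.is_colocate = is_colocate
            # stands in for self.module_args.force_free_memory
            self.force_free_memory = force_free_memory

        def offload(self):
            # Before this commit the guard was `if not self.is_colocate: return`,
            # so a non-colocated policy model never freed its GPU memory.
            if not (self.is_colocate or self.force_free_memory):
                return
            print("freeing grad buffers, then offloading weights")

    ToyModule(is_colocate=False, force_free_memory=False).offload()  # no-op, old behavior
    ToyModule(is_colocate=False, force_free_memory=True).offload()   # now offloads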
6 changes: 3 additions & 3 deletions chatlearn/runtime/executor.py

@@ -298,9 +298,9 @@ def compute_loop_one_model(self, model_node, num_batch=None):
         results = []
         self.timers(f"{model.name}").start()
         for step in range(num_batch):
-            to_empty_cache = step >= last_step_start and model.is_colocate
-            to_onload = step < replica_num and model.is_colocate and model.enable_offload
-            to_offload = step >= last_step_start and model.is_colocate and model.enable_offload
+            to_empty_cache = step >= last_step_start and (model.is_colocate or model.module_args.force_free_memory)
+            to_onload = step < replica_num and ((model.is_colocate and model.enable_offload) or model.module_args.force_free_memory)
+            to_offload = step >= last_step_start and ((model.is_colocate and model.enable_offload) or model.module_args.force_free_memory)
             replica = self._next_model(model)
             _, data = self.generate_step_one_model(model_node, replica, in_queue, model_node.out_queues, step, func_name, to_empty_cache,
                                                    is_eval=is_eval, to_onload=to_onload, to_offload=to_offload)
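To see when the three flags fire across the compute loop, here is a hedged standalone replay of the boolean logic above. The values of `replica_num`, `num_batch`, and `last_step_start` are made up for the demo; the real executor derives them from the model's replicas and batch count:

    # Illustrative replay of the three conditions from the diff.
    replica_num = 2
    num_batch = 6
    last_step_start = num_batch - replica_num  # assumption for this demo only

    def step_flags(step, is_colocate, enable_offload, force_free_memory):
        to_empty_cache = step >= last_step_start and (is_colocate or force_free_memory)
        to_onload = step < replica_num and ((is_colocate and enable_offload) or force_free_memory)
        to_offload = step >= last_step_start and ((is_colocate and enable_offload) or force_free_memory)
        return to_empty_cache, to_onload, to_offload

    # A non-colocated policy model (enable_offload False) used to get
    # (False, False, False) on every step; with force_free_memory=True it
    # onloads on the first steps and offloads/empties cache on the last.
    for step in range(num_batch):
        print(step, step_flags(step, is_colocate=False, enable_offload=False, force_free_memory=True))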
2 changes: 2 additions & 0 deletions chatlearn/utils/arguments.py

@@ -234,6 +234,8 @@ class ModelConfig(BaseConfig):
     free_grad_buffers = False
     #: overall switch for offload optimizer states/weights and free grad buffers
     free_memory = False
+    #: force freeing memory even when the model is not colocated
+    force_free_memory = False

     def __init__(self):
         super().__init__()
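Because the new class attribute defaults to False, existing configs are unaffected and the force-free path is strictly opt-in. A quick sketch of the opt-in (the import path follows the diff, but treat the snippet as illustrative):

    # Sketch: the attribute added above defaults to False.
    from chatlearn.utils.arguments import ModelConfig  # path from the diff

    cfg = ModelConfig()
    assert cfg.force_free_memory is False   # default added in this commit
    cfg.force_free_memory = True            # opt in for a non-colocated policy model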
1 change: 1 addition & 0 deletions examples/megatron/configs/llama2/grpo_math_vllm.yaml

@@ -16,6 +16,7 @@ models:
     ranking: ${batch_generation_ranking:False}
     min_prompt_length: ${batch_generation_min_prompt_length:0}
     free_memory: ${free_memory_policy:False}
+    force_free_memory: ${force_free_memory_policy:False}

   reference:
     model_config_file: reference.yaml
1 change: 1 addition & 0 deletions examples/megatron/configs/llama2/online_dpo_vllm.yaml

@@ -16,6 +16,7 @@ models:
     ranking: ${batch_generation_ranking:False}
     min_prompt_length: ${batch_generation_min_prompt_length:0}
     free_memory: ${free_memory_policy:False}
+    force_free_memory: ${force_free_memory_policy:False}

   reference:
     model_config_file: reference.yaml
1 change: 1 addition & 0 deletions examples/megatron/configs/llama2/rlhf.yaml

@@ -16,6 +16,7 @@ models:
     ranking: ${batch_generation_ranking:False}
     min_prompt_length: ${batch_generation_min_prompt_length:0}
     free_memory: ${free_memory_policy:False}
+    force_free_memory: ${force_free_memory_policy:False}

   reference:
     model_config_file: reference.yaml
1 change: 1 addition & 0 deletions examples/megatron/configs/llama2/vllm_rlhf.yaml

@@ -16,6 +16,7 @@ models:
     ranking: ${batch_generation_ranking:False}
     min_prompt_length: ${batch_generation_min_prompt_length:0}
     free_memory: ${free_memory_policy:False}
+    force_free_memory: ${force_free_memory_policy:False}

   reference:
     model_config_file: reference.yaml
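All four example configs wire the new flag through a `${name:default}` placeholder. Assuming these placeholders resolve from environment variables, as the naming and defaults suggest (an assumption, not confirmed by this diff), the resolution behaves roughly like:

    # Hedged sketch of ${name:default} resolution; ChatLearn's actual
    # config loader may differ -- this only models the apparent contract.
    import os
    import re

    def resolve(value: str) -> str:
        match = re.fullmatch(r"\$\{(\w+):([^}]*)\}", value)
        if match is None:
            return value  # plain literal, no placeholder
        name, default = match.groups()
        return os.environ.get(name, default)

    print(resolve("${force_free_memory_policy:False}"))  # "False" unless the env var is set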
