
[CB] Scheduling constraints regarding number of available blocks/pages #261


Merged
24 commits merged from ysc-max_blocks-scheduler-constraint into main on Jul 15, 2025

Conversation

@yannicks1 yannicks1 (Collaborator) commented Jun 24, 2025

[CB] Scheduling constraints regarding number of available blocks/pages

Changes:

  • moved the hard-coded BLOCK_SIZE (64) variable to the Platform class and import it where needed, instead of defining it in multiple places.
  • introduced scheduler constraints on the number of available blocks/pages in can_schedule() (the reserved block ids per request are now tracked in model_runner.reserved_blocks).
  • wrote a unit test for the new scheduler constraint, asserting n_reserved_blocks and n_used_blocks.
  • introduced the env variable VLLM_SPYRE_N_BLOCKS to set the number of available blocks for the unit tests (it needs to be set during initialization of all classes; if someone knows a better way, please tell me).
  • use math.ceil(n / d) instead of ((n + d - 1) // d) for better readability (see the small sketch below).
  • renamed model_runner.free_pages to model_runner.block_pool (matching the name used in upstream vLLM).

closes #260
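
For reference, a minimal sketch of the ceiling-division change (block_size matches the hard-coded BLOCK_SIZE of 64 mentioned above; num_tokens is an arbitrary example value, not taken from the PR):

import math

block_size = 64   # the hard-coded BLOCK_SIZE moved to the Platform class
num_tokens = 100  # arbitrary example request length

# Both forms compute ceil(num_tokens / block_size), i.e. the number of
# blocks/pages needed to hold num_tokens tokens.
blocks_old = (num_tokens + block_size - 1) // block_size
blocks_new = math.ceil(num_tokens / block_size)

assert blocks_old == blocks_new == 2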


👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: make sure your code passes all the linting checks, otherwise your PR can't be merged. To do so, first install the linting requirements, then run format.sh and commit the changes. This can be done with uv directly:

uv sync --frozen --group lint --active --inexact

Or this can be done with pip:

uv pip compile --group lint > requirements-lint.txt
pip install -r requirements-lint.txt
bash format.sh

Now you are good to go 🚀

Quoted diff context (test-case snippet):

    },
    {
        # Prefill sequence 0
        # total blocks in use: 1
@prashantgupta24 prashantgupta24 (Collaborator) commented Jun 26, 2025:

FYI

scheduler.n_free_blocks

should give us the number of free blocks at each step and should be assertable :)

We can technically get the exact blocks used by:

engine_core.model_executor.driver_worker.worker.model_runner.free_blocks

but that might be overkill here?
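
A minimal sketch of how those assertions could look in the unit test (the fixture names and the exact shape of free_blocks are assumptions, not the repo's actual API):

# Hypothetical helper for the scheduler unit test; `scheduler`, `engine_core`
# and the attribute shapes are placeholders for the real fixtures.
def check_block_accounting(scheduler, engine_core, expected_free_blocks: int) -> None:
    # Cheap check via the scheduler's own bookkeeping.
    assert scheduler.n_free_blocks == expected_free_blocks

    # Stronger (possibly overkill) check: reach into the model runner directly,
    # assuming free_blocks is a collection of free block ids.
    model_runner = engine_core.model_executor.driver_worker.worker.model_runner
    assert len(model_runner.free_blocks) == expected_free_blocks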

@yannicks1 (Collaborator, Author) replied:

no, I think this is a nice additional check! Will add this

@yannicks1 (Collaborator, Author) replied:

done: e5697bc

@yannicks1 yannicks1 marked this pull request as ready for review July 4, 2025 14:15
@sducouedic (Collaborator) commented:

LGTM

Comment on lines 1208 to 1209:
max_requested_blocks[req_id] = len(req_ids2blocks[req_id])
max_reserved_blocks[req_id] = reserved_blocks[req_id]
A collaborator commented:

I am confused why these two variables are prefixed with 'max'

@yannicks1 (Collaborator, Author) replied:

Good point, that was a relic of the past. I just updated the variable names!

Quoted diff context (env variables module):

@@ -74,6 +75,11 @@ def _backend_backwards_compat() -> str:
    "VLLM_SPYRE_USE_CB":
    lambda: bool(int(os.getenv("VLLM_SPYRE_USE_CB", "0"))),

    # If set, use the V1 continuous batching implementation. Otherwise, static
    # batching mode will be enabled.
A collaborator commented:

I don't think this comment is correct? From the code it looks like this is an override for the number of KV cache blocks, used in place of the simple max_model_len * max_num_seqs calculation.

I'm assuming this is for passing in known-good values for the available blocks given a tested model and card combo. Can we also consider using the existing KV cache override instead of setting up a new one? There's --num-gpu-blocks-override, which will set the available blocks in the scheduler config.
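
For illustration, a hedged sketch of how that existing upstream override is typically passed (the model name and value are placeholders, and whether vllm-spyre would pick this up unchanged is exactly what is being discussed here):

# CLI form (upstream vLLM):
#   vllm serve <model> --num-gpu-blocks-override 2048
#
# Programmatic form via EngineArgs (parameter name per upstream vLLM; sketch only):
from vllm import EngineArgs

engine_args = EngineArgs(
    model="some/model",            # placeholder model name
    num_gpu_blocks_override=2048,  # forces the number of KV cache blocks
)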

@yannicks1 (Collaborator, Author) replied:

Yes, the comment comes from copy-pasting the code for VLLM_SPYRE_USE_CB 😆
Good pointer with --num-gpu-blocks-override, I will look into that!
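
For context, the new entry would presumably mirror the VLLM_SPYRE_USE_CB pattern quoted above; a hedged sketch (not the actual diff) of what a corrected entry and comment could look like:

import os

# Sketch only: mirrors the lambda-per-variable pattern shown above; the real
# env module wiring may differ.
environment_variables = {
    # If set to a positive value, overrides the number of available KV cache
    # blocks/pages (used by the unit tests for the new scheduler constraints).
    "VLLM_SPYRE_N_BLOCKS":
    lambda: int(os.getenv("VLLM_SPYRE_N_BLOCKS", "0")),
}

print(environment_variables["VLLM_SPYRE_N_BLOCKS"]())  # 0 unless the variable is set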

Quoted diff context (worker code under review):

# overwrite n_blocks_avail for testing scheduler constraints
if envs_spyre.VLLM_SPYRE_N_BLOCKS > 0:
    n_blocks_avail = envs_spyre.VLLM_SPYRE_N_BLOCKS
model_runner._set_blocks(num_blocks=n_blocks_avail)
@joerunde (Collaborator) commented:

What I'm seeing here is that:

  • self._get_num_blocks_available() only uses info from self.model_runner
  • The resulting n_blocks_avail is only used to modify the model runner, and the model runner's model

So I think that _get_num_blocks_available needs to move to the model runner class, and it should be responsible for finalizing itself after the warmup is complete

@yannicks1 (Collaborator, Author) replied:

Although this was not part of this PR, you are absolutely right, and I moved the function 😇

@joerunde joerunde (Collaborator) commented Jul 9, 2025:

Ah, what I meant here is that the model runner should encapsulate this entirely. This whole block can be replaced with model_runner.finish_warmup(), and there should be no access to the model runner's private methods, or direct access to the model, here in the worker.
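
A rough sketch of the suggested encapsulation (only finish_warmup(), _set_blocks(), and the VLLM_SPYRE_N_BLOCKS override come from this conversation; the class name and everything else are illustrative):

import os


class SpyreModelRunnerSketch:
    """Illustrative stand-in for the real model runner class."""

    def __init__(self) -> None:
        self.n_blocks = 0

    def _get_num_blocks_available(self) -> int:
        # Placeholder value; the real method computes the number of blocks.
        return 2080

    def _set_blocks(self, num_blocks: int) -> None:
        self.n_blocks = num_blocks

    def finish_warmup(self) -> None:
        # Finalize block accounting after warmup, honoring the test override,
        # so the worker only has to call model_runner.finish_warmup().
        n_blocks_avail = self._get_num_blocks_available()
        override = int(os.getenv("VLLM_SPYRE_N_BLOCKS", "0"))
        if override > 0:
            n_blocks_avail = override
        self._set_blocks(num_blocks=n_blocks_avail)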

"""Function returns the number of available blocks/pages.
Will eventually contain a function in torch_sendnn which reads
the actual value provided by the compiler for backend sendnn"""

A collaborator commented:

can we consolidate some of the if envs_spyre.VLLM_SPYRE_N_BLOCKS > 0: checks into here and return the override from this method?
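
A hedged sketch of what that consolidation might look like (the function body is illustrative; only the method's purpose, the env variable, and the max_model_len * max_num_seqs sizing are taken from this conversation):

import math
import os


def get_num_blocks_available(max_model_len: int, max_num_seqs: int,
                             block_size: int = 64) -> int:
    # Return the override directly, so callers no longer need their own
    # "if envs_spyre.VLLM_SPYRE_N_BLOCKS > 0:" checks.
    override = int(os.getenv("VLLM_SPYRE_N_BLOCKS", "0"))
    if override > 0:
        return override
    # Placeholder based on the simple max_model_len * max_num_seqs sizing
    # mentioned earlier; the sendnn backend is expected to eventually report
    # the real value from the compiler.
    return max_num_seqs * math.ceil(max_model_len / block_size)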

@yannicks1 yannicks1 merged commit 009c7a5 into main Jul 15, 2025
17 of 18 checks passed
@yannicks1 yannicks1 deleted the ysc-max_blocks-scheduler-constraint branch July 15, 2025 15:39