[Feat] toploc2 #360
Pull Request Overview
This PR introduces a new Toploc2Sampler for logit sampling, propagates per-sample seeds through the pipeline into Parquet outputs, and adjusts related tests and utilities to include the seed field.
- Add `Toploc2Sampler` and switch to it when appropriate, disabling chunked prefill for correctness
- Introduce and propagate a `seed` field in configs, inference logic, Parquet schema, and tests
- Add a model validator to convert negative `logprobs` config values to `None`
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.
File | Description |
---|---|
tests/integration/inference/test_debug_pp.py | Standardize model name arguments |
tests/conftest.py | Add dummy seed column to test tables |
src/zeroband/utils/parquet.py | Extend Parquet schema with seed field |
src/zeroband/inference/toploc2.py | Implement the new Toploc2 sampling layer |
src/zeroband/inference/parquet.py | Pass seed values into Parquet records |
src/zeroband/inference/config.py | Validate and normalize logprobs setting |
src/zeroband/infer.py | Wire up Toploc2Sampler and seed logic |
Comments suppressed due to low confidence (2)
src/zeroband/inference/toploc2.py:124
- This `TODO` should be resolved or removed; if rank information is required for logprob metadata, implement a clear method to retrieve it and document the approach.
# TODO: How did the original code know the rank?
src/zeroband/inference/config.py:22
- [nitpick] The method name and docstring refer to "negative logprobs," but the field is a count of logprobs; consider renaming to `convert_negative_logprobs_count_to_none` or clarifying that `logprobs < 0` disables computation.
def convert_negative_logprobs_to_none(self):
hell yea! nice job. have we tested that the produced output tokens are identical for the default sampler and the toploc sampler, i.e. are we absolutely sure we are not altering model behavior? could be nice to have a simple test for this.
also, i think adding toploc2 into the configs is important. let's get this merged soon, so that i can rebase onto the config refactor :)
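A check like the one requested above could be sketched roughly as follows. Everything here is hypothetical: `baseline_sample` and `RecordingSampler` are toy stand-ins written for illustration, not the project's samplers or the real `Toploc2Sampler`. The sketch only shows the shape of the test: feed both samplers the same logits and the same per-sample seed, then compare the chosen token ids.

```python
import math
import random


def baseline_sample(logits, seed):
    """Toy stand-in for the default sampler: seeded softmax sampling."""
    rng = random.Random(seed)
    peak = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


class RecordingSampler:
    """Toy stand-in for a sampler that also records metadata (hypothetical)."""

    def __init__(self):
        self.trace = []

    def __call__(self, logits, seed):
        token = baseline_sample(logits, seed)
        # Side channel only: recording must not influence the chosen token.
        self.trace.append((seed, token))
        return token


def count_mismatches(sampler_a, sampler_b, num_cases=200, vocab_size=32):
    """Run both samplers on identical inputs and count differing tokens."""
    case_rng = random.Random(0)  # fixed seed keeps the test reproducible
    mismatches = 0
    for _ in range(num_cases):
        logits = [case_rng.uniform(-5.0, 5.0) for _ in range(vocab_size)]
        seed = case_rng.randrange(2**31)
        if sampler_a(logits, seed) != sampler_b(logits, seed):
            mismatches += 1
    return mismatches
```

In a real test the two callables would wrap the actual samplers, and the assertion would be that `count_mismatches(...)` returns zero for a representative set of prompts and seeds.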
@model_validator(mode="after")
def convert_negative_logprobs_to_none(self):
    """Convert negative logprobs values to None to disable logprobs calculation."""
    if self.logprobs is not None and self.logprobs < 0:
        self.logprobs = None
    return self
Feels more intuitive to err when passing negative values?
this is necessary to disable logprobs. I couldn't find a way to pass None, and since the default is 0, there didn't seem to be any way to make it None other than this
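The two behaviours discussed in this thread can be compared side by side in a minimal sketch. To keep it dependency-free, it uses plain dataclasses as a stand-in for the pydantic model in the diff; the class names `LenientConfigSketch` and `StrictConfigSketch` are made up for illustration and do not exist in the PR.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LenientConfigSketch:
    """Mirrors the validator in the diff: a negative count means 'disabled'."""

    logprobs: Optional[int] = 0

    def __post_init__(self):
        # Normalize negative values to None, as the PR's validator does.
        if self.logprobs is not None and self.logprobs < 0:
            self.logprobs = None


@dataclass
class StrictConfigSketch:
    """Alternative raised in review: reject negative values outright."""

    logprobs: Optional[int] = 0

    def __post_init__(self):
        if self.logprobs is not None and self.logprobs < 0:
            raise ValueError("logprobs must be >= 0, or None to disable")
```

With the lenient variant, passing `-1` is a usable way to disable logprobs when `None` cannot be passed directly; the strict variant only works if `None` can be expressed some other way.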
lfgtm v2
No description provided.