Adds batched inference with left-padding #886
base: main
Conversation
I think the tests might need some tweaking (otherwise I might have broken them with a few too many pushes :D). Will leave this as is for review, and will be back ASAP to carry out any fixes that might be desirable or necessary.
Hi @FlimFlamm! Thanks for working on this. I had this partially implemented but never pushed it and I might have lost it because I cannot find it in my stashes 💀. I'll sit on this for a bit and perhaps merge what you have with what I had. This will need tests and some performance benchmarking before landing.
Happy to be of help! For a relatively small change, it definitely affects computation in a lot of scripts (anything that touches generate), which includes things like …

I'm going to tinker and test some more (hopefully to see whether right padding can be more easily cinched in, in case that turns out to be important for model performance).
Pushed some additions and changes that seemed sensible or cleaner. Made a simple padding function for utils (it can do left and right padding), and set up the mask cache to optionally take a padding mask (which can be passed when the KV cache is being set, or directly to build_mask_cache). Also set up the same logic in sequentially.py for testing (seems to work great). Finally, I added optional attention masking to the forward pass of model.py's GPT class (which isn't required, but seems like it would be useful for anyone using special masking).

NOTE: the masking strategy that bakes a batch's left/right padding into the mask cache increases the mask cache by a factor of batch_size (since we need unique padding inside each sequence's mask), but by doing so we don't have to do any tensor work during generation. In theory, if the max sequence length explodes, this strategy loses its edge (because the auto-regressive mask itself scales quadratically), which might make the batch_size factor start to hurt.
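For illustration, here is a minimal sketch of what baking a padding mask into the causal mask cache could look like. It assumes a boolean padding mask where True marks real tokens; the function name echoes build_mask_cache from the discussion, but the signature and shapes here are assumptions rather than the PR's exact code.

```python
import torch
from typing import Optional

def build_mask_cache(
    max_seq_length: int,
    padding_mask: Optional[torch.Tensor] = None,  # (batch, max_seq_length), True = real token
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    # Plain causal mask, broadcastable over the batch: (1, 1, T, T)
    causal = torch.ones(max_seq_length, max_seq_length, dtype=torch.bool, device=device).tril()
    mask = causal[None, None, :, :]
    if padding_mask is None:
        return mask
    # Fold each sequence's padding into its own copy of the causal mask,
    # giving a (batch, 1, T, T) cache: padded key positions can never be attended to.
    key_keep = padding_mask[:, None, None, :].to(device=device)
    return mask & key_keep
```

This is where the batch_size factor mentioned above comes from: with padding folded in, the cache holds one (T, T) mask per sequence instead of a single shared one.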
Thanks for working on this @FlimFlamm, I was working on this functionality on my fork as well, but the kv cache issue is a tricky one. I cloned your repo and tried to run generation on StableLM and TinyLlama, but both produced outputs that were gibberish. I didn't make any changes to your code; any idea what could be going on?
Can I ask exactly what method or CLI arg you used to test? Will try to reproduce and see if I can find the issue.
I just did the following:

```bash
python scripts/download.py --repo_id 'TinyLlama/TinyLlama-1.1B-Chat-v1.0' --from_safetensors 1
python scripts/convert_hf_checkpoint.py --checkpoint_dir 'checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0'
python generate/base.py --checkpoint_dir 'checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0'
```

I set prompts = ["what food do llamas eat?"] and I get outputs that keep repeating words.
Awesome, thanks for the details. Editing...
So I found the problem; I was building my mask incorrectly. The recent push should have the replacement build_mask_cache() function. The only other necessary change was to use right padding instead of left padding, because having corrected the mask I started re-encountering the NaN issue described in pytorch/pytorch#103749. The exact cause of the problem is cases where an entire row of the causal attention mask is "False", which screws with the dot product attention. The fix is apparently common in a lot of repos. Ours would be something like:

```python
def scaled_dot_product_attention(
    self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    scale = 1.0 / math.sqrt(self.config.head_size)
    y = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=(1.0 - mask.to(dtype=q.dtype)) * -10000.0, scale=scale, is_causal=mask is None
    )
    return y.transpose(1, 2)
```

Instead of doing that, I just switched base.py to right padding, but the more I test, the more it looks like the above is a correct way to address the all-False problem. Another legitimate fix for this particular model seems to be just using the padding token that is assigned, and not using a padding mask at all. Whether or not the model itself defines a padding token might be an indicator that no extra masking is required for the padding... Do let me know if the last push makes batched inference work on your end!
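For what it's worth, the all-False-row failure mode is easy to reproduce in isolation. The snippet below is a standalone demonstration (not code from this PR) of why the additive large-negative mask sidesteps the NaNs, at least on the math SDPA backend:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# Causal boolean mask whose first row is entirely False, e.g. a fully
# padded query position: every score in that row gets masked to -inf.
mask = torch.ones(1, 1, 4, 4, dtype=torch.bool).tril()
mask[..., 0, :] = False

# Softmax over an all -inf row is 0/0, so the output contains NaNs.
bad = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(bad.isnan().any())  # expected: True

# Additive variant: masked positions get a large negative bias instead of
# -inf, so the softmax stays finite even for fully masked rows.
additive = (1.0 - mask.to(q.dtype)) * -10000.0
ok = F.scaled_dot_product_attention(q, k, v, attn_mask=additive)
print(ok.isnan().any())  # expected: False
```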
Applying the large-negative number fix seems to have done the trick; left and right padding are now equivalent in terms of output for the TinyLlama chat model.
Running benchmarks on TinyLlama using the original generate code vs your batch implementation yields (almost) identical scores now. Well done!

Update
Very interesting. A few questions/requests that might help me replicate and track down this issue:
I also wonder if the large-negative-number fix might not be ideally implemented here; I'm only using negative 10k (most implementations used …).

Possibly this performance hit is a consequence of batched inference in and of itself? Can't find much about it, but maybe?
This branch and the upstream are starting to diverge, so I'm going to copy your changes into my fork that's up to date and continue to dig around. What are you getting on your end? Is doing 10 prompts in a batch the same as the same 10 prompts one at a time?
@WilliamGazeley Thanks for the effort on this!
Batched inference does give different outputs for each sequence, which I think is by design. The good news is that the first sequence in the batch matches the single unbatched case with the original generate/mask code.
I agree, although assuming there is some small unavoidable performance loss in batched inference cases, I was thinking that an input being more out of its training distribution could amplify the performance degradation (since at 1B this model is relatively brittle, perhaps that also magnifies the issues we're seeing re: performance). I'll keep poking at it as well to see what I can come up with (I'll fire up HellaSwag as soon as I can replicate your findings and start hunting from there).
Playing around further, I noticed that there's a huge difference in outputs if you change between bf16 and 16-true. This is somewhat expected, I guess, but the batched 16-true is closer to the single bf16 than the batched bf16 is to the single bf16; this is only on my benchmark though.

Also, I think your implementation of scaled_dot_product_attention should be something like:

```python
def scaled_dot_product_attention(
    self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    scale = 1.0 / math.sqrt(self.config.head_size)
    y = torch.nn.functional.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None if mask is None else (1.0 - mask.to(dtype=q.dtype)) * -10000.0,
        dropout_p=0.0,
        scale=scale,
        is_causal=mask is None,
    )
    return y.transpose(1, 2)
```
Adds a left-padding batched inference strategy by modifying generate/base.py and model.py. Changes prompt to prompts in generate.py's main() function; it's still compatible with a string, and now also with a list of strings.

Demo video: 2024-01-17.19-35-56.mp4
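As a usage sketch, batching prompts of different lengths means padding the encoded prompts to a common length and tracking a padding mask; something like the hypothetical helper below (pad_batch and its signature are illustrative, not the utility added in this PR):

```python
import torch
from typing import List, Tuple

def pad_batch(
    encoded_prompts: List[torch.Tensor], pad_id: int = 0, left: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Stack variable-length 1-D token tensors into a (batch, T) tensor plus a
    boolean padding mask where True marks real tokens."""
    max_len = max(t.size(0) for t in encoded_prompts)
    ids, mask = [], []
    for t in encoded_prompts:
        pad = torch.full((max_len - t.size(0),), pad_id, dtype=t.dtype)
        keep = torch.ones(t.size(0), dtype=torch.bool)
        gap = torch.zeros(max_len - t.size(0), dtype=torch.bool)
        ids.append(torch.cat([pad, t] if left else [t, pad]))
        mask.append(torch.cat([gap, keep] if left else [keep, gap]))
    return torch.stack(ids), torch.stack(mask)

# e.g. pad_batch([tokenizer.encode(p) for p in prompts], pad_id=0, left=True)
```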
EDIT: currently triaging the test failures.
Constructive feedback is very welcome. If something about this commit would adversely affect other parts of the repo that I have overlooked, I'll do my best to address it.