Skip to content

Commit a7c5007

Browse files
authored
[Fix] Skip empty batch (#747)
1 parent d338635 commit a7c5007

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

src/turbomind/models/llama/LlamaBatch.cc

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -475,6 +475,10 @@ bool LlamaBatch<T>::Initialize()
475475
template<typename T>
476476
void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc)
477477
{
478+
if (desc.empty()) {
479+
return;
480+
}
481+
478482
std::vector<int> idxs(desc.size());
479483
std::iota(idxs.begin(), idxs.end(), 0);
480484

@@ -1430,18 +1434,21 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
14301434
// finished sequences are handled by `Initialize()`
14311435
finished_count = 0;
14321436

1433-
ContextDecode();
1434-
14351437
if (state_->active_size) {
1438+
1439+
ContextDecode();
1440+
14361441
if (modified) {
14371442
g = InitializeGeneration();
14381443
InitializeSampling();
14391444
}
1445+
14401446
for (int i = 0; i < step_length_; ++i) {
14411447
if (!Generate(g)) {
14421448
break;
14431449
}
14441450
}
1451+
14451452
if (auto signals = Finish(g, finished_count); !signals.empty()) {
14461453
if (finished_count) {
14471454
// Finished requests and corresponding output tensors will be released when notified

0 commit comments

Comments
 (0)