Fix filter_batch function call (#1681)

This commit is contained in:
Liangsheng Yin
2024-10-15 22:59:26 -07:00
committed by GitHub
parent f1088e0fc8
commit b6b4094621

View File

@@ -649,7 +649,7 @@ class ScheduleBatch:
req.last_update_decode_tokens = 0
req.logprob_start_len = 10**9
self.filter_batch(sorted_indices)
self.filter_batch(keep_indices=sorted_indices)
# Reqs in batch are filtered
total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)