Fix draft decode max batch size (#3676)

This commit is contained in:
Ke Bao
2025-02-18 23:03:26 +08:00
committed by GitHub
parent f983213a1f
commit e5ce395a6c
3 changed files with 5 additions and 2 deletions

View File

@@ -1094,7 +1094,7 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-        max_bs = model_runner.req_to_token_pool.size
+        max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
                 self.speculative_num_steps,