Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions

View File

@@ -68,6 +68,7 @@ class FlashInferAttnBackend(AttentionBackend):
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
kv_last_page_len_buf: Optional[torch.Tensor] = None,
):
super().__init__()
@@ -125,9 +126,14 @@ class FlashInferAttnBackend(AttentionBackend):
assert self.num_wrappers == 1
self.kv_indptr = [kv_indptr_buf]
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
if kv_last_page_len_buf is None:
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
else:
assert self.num_wrappers == 1
self.kv_last_page_len = kv_last_page_len_buf
self.qo_indptr = [
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
for _ in range(self.num_wrappers)
@@ -922,6 +928,9 @@ class FlashInferMultiStepDraftBackend:
dtype=torch.int32,
device=model_runner.device,
)
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
@@ -929,6 +938,7 @@ class FlashInferMultiStepDraftBackend:
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
kv_last_page_len_buf=self.kv_last_page_len,
)
)
self.max_context_len = self.attn_backends[0].max_context_len