Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
@@ -68,6 +68,7 @@ class FlashInferAttnBackend(AttentionBackend):
         model_runner: ModelRunner,
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()

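The new kv_last_page_len_buf argument lets a caller hand the backend a preallocated buffer instead of each instance allocating its own. For context, kv_last_page_len tells FlashInfer how many tokens occupy each sequence's final KV-cache page; with a page size of 1 every page holds exactly one token, which is why the buffer is filled with ones. A minimal sketch of that relationship, using a hypothetical last_page_lens helper that is not part of this diff:

import torch

def last_page_lens(seq_lens: torch.Tensor, page_size: int) -> torch.Tensor:
    # A sequence that exactly fills its last page maps to page_size, not 0.
    return (seq_lens - 1) % page_size + 1

seq_lens = torch.tensor([5, 8, 13], dtype=torch.int32)
print(last_page_lens(seq_lens, page_size=4))  # tensor([1, 4, 1], dtype=torch.int32)
print(last_page_lens(seq_lens, page_size=1))  # all ones, matching the torch.ones buffers below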
@@ -125,9 +126,14 @@ class FlashInferAttnBackend(AttentionBackend):
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]

-        self.kv_last_page_len = torch.ones(
-            (max_bs,), dtype=torch.int32, device=model_runner.device
-        )
+        if kv_last_page_len_buf is None:
+            self.kv_last_page_len = torch.ones(
+                (max_bs,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            assert self.num_wrappers == 1
+            self.kv_last_page_len = kv_last_page_len_buf
+
         self.qo_indptr = [
             torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
             for _ in range(self.num_wrappers)
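The constructor change above follows a simple allocate-or-reuse pattern: allocate a private buffer when none is supplied, otherwise alias the caller's tensor so no copy is made. A standalone sketch of that pattern; init_kv_last_page_len is a hypothetical name, since the real logic lives inline in __init__:

from typing import Optional

import torch

def init_kv_last_page_len(
    max_bs: int,
    device: str,
    kv_last_page_len_buf: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if kv_last_page_len_buf is None:
        # No shared buffer supplied: allocate a private one.
        return torch.ones((max_bs,), dtype=torch.int32, device=device)
    # Reuse the caller's tensor; mutations are visible to every holder.
    return kv_last_page_len_buf

shared = torch.ones((8,), dtype=torch.int32)
assert init_kv_last_page_len(8, "cpu", shared) is shared  # aliased, not copied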
@@ -922,6 +928,9 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device=model_runner.device,
         )
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
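Sharing one buffer matters here because the draft backend builds one FlashInferAttnBackend per speculative step, and before this change each of them allocated its own kv_last_page_len. Rough arithmetic under illustrative values (max_bs and the step count are assumptions, not taken from the diff):

max_bs, num_steps = 256, 4
bytes_per_buf = max_bs * 4        # int32 entries
print(num_steps * bytes_per_buf)  # before: one buffer per step -> 4096 bytes
print(bytes_per_buf)              # after: a single shared buffer -> 1024 bytes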
@@ -929,6 +938,7 @@ class FlashInferMultiStepDraftBackend:
                     model_runner,
                     skip_prefill=True,
                     kv_indptr_buf=self.kv_indptr[i],
+                    kv_last_page_len_buf=self.kv_last_page_len,
                 )
             )
         self.max_context_len = self.attn_backends[0].max_context_len
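With both hunks applied, every per-step backend holds a reference to the parent's tensor rather than its own copy. A quick way to see the aliasing, assuming draft_backend is an already-constructed FlashInferMultiStepDraftBackend (the variable name is illustrative):

parent_ptr = draft_backend.kv_last_page_len.data_ptr()
assert all(
    b.kv_last_page_len.data_ptr() == parent_ptr  # same storage for every step
    for b in draft_backend.attn_backends
)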