Fix draft decode max batch size (#3676)
@@ -1094,7 +1094,7 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-        max_bs = model_runner.req_to_token_pool.size
+        max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
                 self.speculative_num_steps,
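Why the scaling matters: in multi-step speculative decoding, each running request fans out into topk draft branches, so the draft decode pass can see up to req_to_token_pool.size * topk rows per step, and per-row buffers sized from the unscaled pool size under-allocate. A minimal sketch of the sizing with made-up numbers, assuming the second kv_indptr dimension is max_bs + 1 (the hunk is truncated before that line):

import torch

# Hypothetical sizes, for illustration only.
pool_size = 64              # stands in for model_runner.req_to_token_pool.size
topk = 4                    # draft branching factor per request
speculative_num_steps = 3

# Each request expands into topk branches, so the draft decode
# batch can reach pool_size * topk rows.
max_bs = pool_size * topk

# One indptr slot per row plus a leading zero, per speculative step
# (assumed layout; the hunk above cuts off before this dimension).
kv_indptr = torch.zeros(
    (speculative_num_steps, max_bs + 1),
    dtype=torch.int32,
)

# The old sizing (pool_size alone) leaves too few slots for a full
# topk-expanded batch.
assert kv_indptr.shape[1] - 1 >= pool_size * topk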
@@ -474,7 +474,7 @@ class TritonMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-        max_bs = model_runner.req_to_token_pool.size
+        max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
                 self.speculative_num_steps,
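The Triton backend preallocates its per-step kv_indptr the same way, so it needs the identical topk scaling; fixing only one backend would leave the other with under-sized buffers for the same topk-expanded batch.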
@@ -635,6 +635,9 @@ def decode_attention_fwd(
     logit_cap=0.0,
 ):
     assert num_kv_splits == attn_logits.shape[2]
+    assert q.shape[0] <= kv_indptr.shape[0] - 1
+    assert q.shape[0] <= attn_logits.shape[0]
+
     kv_group_num = q.shape[1] // v_buffer.shape[1]
 
     if kv_group_num == 1:
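The new assertions turn the under-allocation failure mode into a fast, explicit error: if more query rows arrive than the preallocated kv_indptr and attn_logits buffers can address, the kernel aborts instead of reading out of range. A small self-contained check with hypothetical shapes (the real tensors come from the attention backend):

import torch

def check_decode_shapes(q, kv_indptr, attn_logits, num_kv_splits):
    # Mirrors the guards added above: every query row needs an
    # indptr pair and a logits row to write into.
    assert num_kv_splits == attn_logits.shape[2]
    assert q.shape[0] <= kv_indptr.shape[0] - 1
    assert q.shape[0] <= attn_logits.shape[0]

# Hypothetical shapes: batch 8, 16 heads, head_dim 64, buffers
# preallocated for a max batch of 8, num_kv_splits = 4.
q = torch.randn(8, 16, 64)
kv_indptr = torch.zeros(8 + 1, dtype=torch.int32)
attn_logits = torch.zeros(8, 16, 4, 64 + 1)
check_decode_shapes(q, kv_indptr, attn_logits, num_kv_splits=4)

# A 9-row batch against the same buffers would now trip the first
# new assert instead of silently indexing past the end of kv_indptr.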