Reduce one step decode for draft model. (#11561)
This commit is contained in:
@@ -1064,7 +1064,7 @@ class AiterMultiStepDraftBackend:
|
||||
device=model_runner.device,
|
||||
)
|
||||
self.attn_backends = []
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends.append(
|
||||
AiterAttnBackend(
|
||||
model_runner,
|
||||
@@ -1107,7 +1107,7 @@ class AiterMultiStepDraftBackend:
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
||||
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
|
||||
: seq_lens_sum * self.topk + bs * (i + 1)
|
||||
@@ -1141,7 +1141,7 @@ class AiterMultiStepDraftBackend:
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
)
|
||||
|
||||
@@ -2320,7 +2320,7 @@ class FlashAttentionMultiStepBackend:
|
||||
self.topk = topk
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
self.attn_backends = []
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends.append(
|
||||
FlashAttentionBackend(
|
||||
model_runner,
|
||||
@@ -2335,7 +2335,7 @@ class FlashAttentionMultiStepBackend:
|
||||
self.attn_backends[i].init_forward_metadata(forward_batch)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
|
||||
@@ -1405,7 +1405,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
(max_bs,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
self.attn_backends: List[FlashInferAttnBackend] = []
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends.append(
|
||||
FlashInferAttnBackend(
|
||||
model_runner,
|
||||
@@ -1493,7 +1493,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
)
|
||||
|
||||
@@ -916,7 +916,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
)
|
||||
|
||||
self.attn_backends = []
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends.append(
|
||||
FlashInferMLAAttnBackend(
|
||||
model_runner,
|
||||
@@ -998,7 +998,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
)
|
||||
|
||||
@@ -478,7 +478,7 @@ class FlashMLAMultiStepDraftBackend:
|
||||
)
|
||||
|
||||
self.attn_backends = []
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends.append(
|
||||
FlashMLABackend(
|
||||
model_runner,
|
||||
@@ -506,7 +506,7 @@ class FlashMLAMultiStepDraftBackend:
|
||||
self.common_template(forward_batch, call_fn)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs, max_num_tokens, block_kv_indices=None
|
||||
)
|
||||
|
||||
@@ -918,7 +918,7 @@ class TritonMultiStepDraftBackend:
|
||||
device=model_runner.device,
|
||||
)
|
||||
self.attn_backends: List[TritonAttnBackend] = []
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends.append(
|
||||
TritonAttnBackend(
|
||||
model_runner,
|
||||
@@ -969,7 +969,7 @@ class TritonMultiStepDraftBackend:
|
||||
if call_fn is None:
|
||||
return
|
||||
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
||||
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
|
||||
: seq_lens_sum * self.topk + bs * (i + 1)
|
||||
@@ -1009,7 +1009,8 @@ class TritonMultiStepDraftBackend:
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs,
|
||||
max_num_tokens,
|
||||
|
||||
@@ -637,7 +637,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
|
||||
self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
|
||||
):
|
||||
super().__init__(model_runner, topk, speculative_num_steps)
|
||||
- for i in range(speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i] = TRTLLMHAAttnBackend(
|
||||
model_runner,
|
||||
skip_prefill=True,
|
||||
@@ -651,7 +651,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
|
||||
self.attn_backends[i].init_forward_metadata(forward_batch)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
|
||||
@@ -735,7 +735,7 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
|
||||
):
|
||||
super().__init__(model_runner, topk, speculative_num_steps)
|
||||
|
||||
- for i in range(self.speculative_num_steps):
|
||||
+ for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i] = TRTLLMMLABackend(
|
||||
model_runner,
|
||||
skip_prefill=True,
|
||||
|
||||
Reference in New Issue
Block a user