[Fix] Fix EAGLE when CUDA graph is disabled (#3411)
This commit is contained in:
@@ -924,38 +924,50 @@ class FlashInferMultiStepDraftBackend:
|
||||
self.max_context_len = self.attn_backends[0].max_context_len
|
||||
# Cached variables for generate_draft_decode_kv_indices
|
||||
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||
self.kv_indptr_stride = self.kv_indptr.shape[1]
|
||||
|
||||
def common_template(self, forward_batch: ForwardBatch, call_fn: int):
|
||||
def common_template(
|
||||
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
|
||||
):
|
||||
num_seqs = forward_batch.batch_size
|
||||
bs = self.topk * num_seqs
|
||||
seq_lens_sum = forward_batch.seq_lens_sum
|
||||
|
||||
self.generate_draft_decode_kv_indices[
|
||||
(self.speculative_num_steps, num_seqs, self.topk)
|
||||
](
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.req_to_token_pool.req_to_token,
|
||||
forward_batch.seq_lens,
|
||||
self.cuda_graph_kv_indices,
|
||||
kv_indices_buffer,
|
||||
self.kv_indptr,
|
||||
forward_batch.positions,
|
||||
num_seqs,
|
||||
self.topk,
|
||||
self.pool_len,
|
||||
self.kv_indptr_stride,
|
||||
kv_indices_buffer.shape[1],
|
||||
self.kv_indptr.shape[1],
|
||||
triton.next_power_of_2(num_seqs),
|
||||
triton.next_power_of_2(self.speculative_num_steps),
|
||||
triton.next_power_of_2(bs),
|
||||
)
|
||||
|
||||
for i in range(self.speculative_num_steps):
|
||||
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
||||
forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
|
||||
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
|
||||
: seq_lens_sum * self.topk + bs * (i + 1)
|
||||
]
|
||||
call_fn(i, forward_batch)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
kv_indices = torch.zeros(
|
||||
(
|
||||
self.speculative_num_steps,
|
||||
forward_batch.batch_size * self.topk * self.max_context_len,
|
||||
),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
def call_fn(i, forward_batch):
|
||||
forward_batch.spec_info.kv_indptr = (
|
||||
forward_batch.spec_info.kv_indptr.clone()
|
||||
@@ -965,7 +977,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
)
|
||||
self.attn_backends[i].init_forward_metadata(forward_batch)
|
||||
|
||||
self.common_template(forward_batch, call_fn)
|
||||
self.common_template(forward_batch, kv_indices, call_fn)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
@@ -973,7 +985,6 @@ class FlashInferMultiStepDraftBackend:
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
@@ -995,7 +1006,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
][0]
|
||||
decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
|
||||
|
||||
self.common_template(forward_batch, call_fn)
|
||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
|
||||
def call_fn(i, forward_batch):
|
||||
@@ -1009,7 +1020,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
spec_info=forward_batch.spec_info,
|
||||
)
|
||||
|
||||
self.common_template(forward_batch, call_fn)
|
||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
Reference in New Issue
Block a user