[Fix] Fix raw_bs bug when using flashinfer mla and eagle (#4557)

This commit is contained in:
Baizhou Zhang
2025-03-18 21:26:53 -07:00
committed by GitHub
parent c0e9a36c5f
commit 90532b7627

View File

@@ -52,6 +52,9 @@ class EAGLEDraftCudaGraphRunner:
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
0
].get_cuda_graph_seq_len_fill_value()
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
if self.enable_torch_compile:
set_torch_compile_config()
@@ -210,6 +213,12 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.req_pool_indices = self.req_pool_indices[:bs]
forward_batch.positions = self.positions[:num_tokens]
# Special handle for seq_len_cpu used when flashinfer mla is used
if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
self.seq_lens_cpu.fill_(1)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
forward_batch, bs
)
@@ -224,5 +233,7 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.positions = self.positions[:raw_num_token]
forward_batch.seq_lens = self.seq_lens[:raw_bs]
forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
if forward_batch.decode_seq_lens_cpu is not None:
forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
return out