[Fix] Fix raw_bs bug when using flashinfer mla and eagle (#4557)
This commit is contained in:
@@ -52,6 +52,9 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
|
||||
0
|
||||
].get_cuda_graph_seq_len_fill_value()
|
||||
self.seq_lens_cpu = torch.full(
|
||||
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||
)
|
||||
|
||||
if self.enable_torch_compile:
|
||||
set_torch_compile_config()
|
||||
@@ -210,6 +213,12 @@ class EAGLEDraftCudaGraphRunner:
|
||||
forward_batch.req_pool_indices = self.req_pool_indices[:bs]
|
||||
forward_batch.positions = self.positions[:num_tokens]
|
||||
|
||||
# Special handle for seq_len_cpu used when flashinfer mla is used
|
||||
if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
|
||||
self.seq_lens_cpu.fill_(1)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
|
||||
forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
|
||||
|
||||
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
forward_batch, bs
|
||||
)
|
||||
@@ -224,5 +233,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
forward_batch.positions = self.positions[:raw_num_token]
|
||||
forward_batch.seq_lens = self.seq_lens[:raw_bs]
|
||||
forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
|
||||
if forward_batch.decode_seq_lens_cpu is not None:
|
||||
forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
|
||||
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user