[Fix] Fix raw_bs bug when using flashinfer mla and eagle (#4557)
This commit is contained in:
@@ -52,6 +52,9 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
|
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
|
||||||
0
|
0
|
||||||
].get_cuda_graph_seq_len_fill_value()
|
].get_cuda_graph_seq_len_fill_value()
|
||||||
|
self.seq_lens_cpu = torch.full(
|
||||||
|
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||||
|
)
|
||||||
|
|
||||||
if self.enable_torch_compile:
|
if self.enable_torch_compile:
|
||||||
set_torch_compile_config()
|
set_torch_compile_config()
|
||||||
@@ -210,6 +213,12 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
forward_batch.req_pool_indices = self.req_pool_indices[:bs]
|
forward_batch.req_pool_indices = self.req_pool_indices[:bs]
|
||||||
forward_batch.positions = self.positions[:num_tokens]
|
forward_batch.positions = self.positions[:num_tokens]
|
||||||
|
|
||||||
|
# Special handle for seq_len_cpu used when flashinfer mla is used
|
||||||
|
if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
|
||||||
|
self.seq_lens_cpu.fill_(1)
|
||||||
|
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
|
||||||
|
forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
|
||||||
|
|
||||||
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
|
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
forward_batch, bs
|
forward_batch, bs
|
||||||
)
|
)
|
||||||
@@ -224,5 +233,7 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
forward_batch.positions = self.positions[:raw_num_token]
|
forward_batch.positions = self.positions[:raw_num_token]
|
||||||
forward_batch.seq_lens = self.seq_lens[:raw_bs]
|
forward_batch.seq_lens = self.seq_lens[:raw_bs]
|
||||||
forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
|
forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
|
||||||
|
if forward_batch.decode_seq_lens_cpu is not None:
|
||||||
|
forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|||||||
Reference in New Issue
Block a user