[Revision] Add fast decode plan for flashinfer mla (#4012)

This commit is contained in:
Baizhou Zhang
2025-03-05 11:20:41 -08:00
committed by GitHub
parent 71ab0dabe0
commit fc91d08a8f
9 changed files with 145 additions and 34 deletions

View File

@@ -200,6 +200,9 @@ class CudaGraphRunner:
)
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
if self.enable_torch_compile:
set_torch_compile_config()
@@ -448,6 +451,10 @@ class CudaGraphRunner:
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)
if forward_batch.decode_seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(1)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
@@ -466,6 +473,7 @@ class CudaGraphRunner:
self.encoder_lens,
forward_batch.forward_mode,
forward_batch.spec_info,
seq_lens_cpu=self.seq_lens_cpu,
)
# Replay

View File

@@ -156,6 +156,9 @@ class ForwardBatch:
# Position information
positions: torch.Tensor = None
# For decode
decode_seq_lens_cpu: Optional[torch.Tensor] = None
# For extend
extend_num_tokens: Optional[int] = None
extend_seq_lens: Optional[torch.Tensor] = None
@@ -280,6 +283,8 @@ class ForwardBatch:
if ret.forward_mode.is_decode():
if ret.positions is None:
ret.positions = clamp_position(batch.seq_lens)
if ret.decode_seq_lens_cpu is None:
ret.decode_seq_lens_cpu = batch.decode_seq_lens
else:
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32