Add fast decode plan for flashinfer mla (#3987)
This commit is contained in:
@@ -199,6 +199,10 @@ class CudaGraphRunner:
|
||||
if self.enable_torch_compile:
|
||||
set_torch_compile_config()
|
||||
|
||||
self.seq_lens_cpu = torch.full(
|
||||
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||
)
|
||||
|
||||
# Graph inputs
|
||||
with torch.device("cuda"):
|
||||
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||
@@ -369,9 +373,9 @@ class CudaGraphRunner:
|
||||
num_tokens,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
encoder_lens,
|
||||
forward_batch.forward_mode,
|
||||
forward_batch.spec_info,
|
||||
encoder_lens=encoder_lens,
|
||||
spec_info=forward_batch.spec_info,
|
||||
)
|
||||
|
||||
# Run and capture
|
||||
@@ -434,6 +438,7 @@ class CudaGraphRunner:
|
||||
if bs != raw_bs:
|
||||
self.seq_lens.fill_(1)
|
||||
self.out_cache_loc.zero_()
|
||||
self.seq_lens_cpu.fill_(1)
|
||||
|
||||
# Common inputs
|
||||
self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
|
||||
@@ -441,6 +446,8 @@ class CudaGraphRunner:
|
||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
||||
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
||||
if forward_batch.decode_seq_lens_cpu is not None:
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
@@ -456,9 +463,10 @@ class CudaGraphRunner:
|
||||
self.req_pool_indices,
|
||||
self.seq_lens,
|
||||
forward_batch.seq_lens_sum + (bs - raw_bs),
|
||||
self.encoder_lens,
|
||||
forward_batch.forward_mode,
|
||||
forward_batch.spec_info,
|
||||
encoder_lens=self.encoder_lens,
|
||||
spec_info=forward_batch.spec_info,
|
||||
seq_lens_cpu=self.seq_lens_cpu,
|
||||
)
|
||||
|
||||
# Replay
|
||||
|
||||
@@ -152,6 +152,9 @@ class ForwardBatch:
|
||||
# Position information
|
||||
positions: torch.Tensor = None
|
||||
|
||||
# For decode
|
||||
decode_seq_lens_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
# For extend
|
||||
extend_num_tokens: Optional[int] = None
|
||||
extend_seq_lens: Optional[torch.Tensor] = None
|
||||
@@ -256,6 +259,8 @@ class ForwardBatch:
|
||||
if ret.forward_mode.is_decode():
|
||||
if ret.positions is None:
|
||||
ret.positions = clamp_position(batch.seq_lens)
|
||||
if ret.decode_seq_lens_cpu is None:
|
||||
ret.decode_seq_lens_cpu = batch.decode_seq_lens
|
||||
else:
|
||||
ret.extend_seq_lens = torch.tensor(
|
||||
batch.extend_seq_lens, dtype=torch.int32
|
||||
|
||||
Reference in New Issue
Block a user