Revert "Add fast decode plan for flashinfer mla" (#4008)

This commit is contained in:
Lianmin Zheng
2025-03-02 19:29:10 -08:00
committed by GitHub
parent fa56106731
commit 9e1014cf99
9 changed files with 52 additions and 156 deletions

View File

@@ -199,10 +199,6 @@ class CudaGraphRunner:
if self.enable_torch_compile:
set_torch_compile_config()
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
# Graph inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
@@ -373,9 +369,9 @@ class CudaGraphRunner:
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
forward_batch.forward_mode,
encoder_lens=encoder_lens,
spec_info=forward_batch.spec_info,
forward_batch.spec_info,
)
# Run and capture
@@ -438,7 +434,6 @@ class CudaGraphRunner:
if bs != raw_bs:
self.seq_lens.fill_(1)
self.out_cache_loc.zero_()
self.seq_lens_cpu.fill_(1)
# Common inputs
self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
@@ -446,8 +441,6 @@ class CudaGraphRunner:
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)
if forward_batch.decode_seq_lens_cpu is not None:
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
@@ -463,10 +456,9 @@ class CudaGraphRunner:
self.req_pool_indices,
self.seq_lens,
forward_batch.seq_lens_sum + (bs - raw_bs),
self.encoder_lens,
forward_batch.forward_mode,
encoder_lens=self.encoder_lens,
spec_info=forward_batch.spec_info,
seq_lens_cpu=self.seq_lens_cpu,
forward_batch.spec_info,
)
# Replay

View File

@@ -152,9 +152,6 @@ class ForwardBatch:
# Position information
positions: torch.Tensor = None
# For decode
decode_seq_lens_cpu: Optional[torch.Tensor] = None
# For extend
extend_num_tokens: Optional[int] = None
extend_seq_lens: Optional[torch.Tensor] = None
@@ -259,8 +256,6 @@ class ForwardBatch:
if ret.forward_mode.is_decode():
if ret.positions is None:
ret.positions = clamp_position(batch.seq_lens)
if ret.decode_seq_lens_cpu is None:
ret.decode_seq_lens_cpu = batch.decode_seq_lens
else:
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32