Revert "Add fast decode plan for flashinfer mla" (#4008)

This commit is contained in:
Lianmin Zheng
2025-03-02 19:29:10 -08:00
committed by GitHub
parent fa56106731
commit 9e1014cf99
9 changed files with 52 additions and 156 deletions

View File

@@ -29,8 +29,9 @@ class AttentionBackend(ABC):
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
**kwargs,
spec_info: Optional[SpecInfo],
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
@@ -41,8 +42,9 @@ class AttentionBackend(ABC):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
**kwargs,
spec_info: Optional[SpecInfo],
):
"""Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError()

View File

@@ -269,10 +269,9 @@ class FlashInferAttnBackend(AttentionBackend):
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
forward_mode: ForwardMode,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
**kwargs,
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []
@@ -340,10 +339,9 @@ class FlashInferAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
forward_mode: ForwardMode,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
**kwargs,
):
if forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(

View File

@@ -10,7 +10,6 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html
"""
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Optional, Union
import torch
@@ -28,12 +27,14 @@ from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
if is_flashinfer_available():
from flashinfer import (
BatchMLAPagedAttentionWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
@dataclass
@@ -62,7 +63,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
# Parse constants
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
global_config.enable_flashinfer_mla = True
@@ -85,6 +85,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
self.q_indptr_decode = torch.arange(
0, max_bs + 1, dtype=torch.int32, device=model_runner.device
)
@@ -122,7 +126,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
decode_wrapper=self.decode_wrapper,
init_metadata_replay=False,
)
self.forward_metadata = DecodeMetadata(self.decode_wrapper)
else:
@@ -158,20 +161,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
cuda_graph_kv_indices = kv_indices_buf
self.cuda_graph_kv_indices = cuda_graph_kv_indices
self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
self.cuda_graph_kv_indptr = self.kv_indptr.clone()
self.cuda_graph_kv_lens = torch.ones(
(max_bs,), dtype=torch.int32, device=self.device
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device="cuda",
)
# For fast decode plan in graph replaying
self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
self.fast_decode_kwargs = {
"qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu,
"kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu,
"kv_indices": self.cuda_graph_kv_indices,
}
self.cuda_graph_qk_indptr = self.kv_indptr.clone()
self.cuda_graph_qo_indptr = self.kv_indptr.clone()
def init_forward_metadata_capture_cuda_graph(
self,
@@ -179,17 +175,18 @@ class FlashInferMLAAttnBackend(AttentionBackend):
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
**kwargs,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
decode_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
use_cuda_graph=True,
qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],
kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],
qo_indptr=self.qo_indptr[: num_tokens + 1],
kv_indptr=self.kv_indptr[: num_tokens + 1],
kv_indices=self.cuda_graph_kv_indices,
kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
kv_len_arr=self.kv_last_page_len[:num_tokens],
backend="auto",
)
@@ -199,11 +196,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
seq_lens,
seq_lens_sum,
decode_wrapper=decode_wrapper,
init_metadata_replay=False,
)
self.decode_cuda_graph_metadata[bs] = decode_wrapper
self.forward_metadata = DecodeMetadata(decode_wrapper)
decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")
@@ -213,30 +208,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
seq_lens_cpu: torch.Tensor,
**kwargs,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
kv_len_arr_cpu = seq_lens_cpu[:bs]
self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
kv_len_arr_cpu, dim=0
)
self.fast_decode_kwargs.update(
{
"qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1],
"kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1],
"kv_len_arr_cpu": kv_len_arr_cpu,
}
)
self.indices_updater_decode.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
decode_wrapper=self.decode_cuda_graph_metadata[bs],
init_metadata_replay=True,
**self.fast_decode_kwargs,
)
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")
@@ -336,6 +317,7 @@ class FlashInferMLAIndicesUpdaterDecode:
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.q_indptr = attn_backend.q_indptr_decode
@@ -345,8 +327,6 @@ class FlashInferMLAIndicesUpdaterDecode:
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrapper: BatchMLAPagedAttentionWrapper,
init_metadata_replay: bool = False,
**fast_decode_kwargs,
):
decode_wrapper = decode_wrapper or self.decode_wrapper
self.call_begin_forward(
@@ -356,8 +336,6 @@ class FlashInferMLAIndicesUpdaterDecode:
seq_lens_sum,
self.q_indptr,
self.kv_indptr,
init_metadata_replay,
**fast_decode_kwargs,
)
def call_begin_forward(
@@ -368,19 +346,14 @@ class FlashInferMLAIndicesUpdaterDecode:
paged_kernel_lens_sum: int,
q_indptr: torch.Tensor,
kv_indptr: torch.Tensor,
init_metadata_replay: bool = False,
**fast_decode_kwargs,
):
bs = len(req_pool_indices)
q_indptr = q_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = (
torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
if not init_metadata_replay
else fast_decode_kwargs["kv_indices"]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
kv_lens = paged_kernel_lens.to(torch.int32)
sm_scale = self.scaling
@@ -393,36 +366,21 @@ class FlashInferMLAIndicesUpdaterDecode:
kv_indices,
self.req_to_token.shape[1],
)
if not init_metadata_replay:
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)
else:
wrapper.plan(
fast_decode_kwargs["qo_indptr_cpu"],
fast_decode_kwargs["kv_indptr_cpu"],
kv_indices,
fast_decode_kwargs["kv_len_arr_cpu"],
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)
class FlashInferMLAIndicesUpdaterPrefill:
@@ -442,6 +400,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.qo_indptr = attn_backend.qo_indptr
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
@@ -538,42 +497,3 @@ class FlashInferMLAIndicesUpdaterPrefill:
self.q_data_type,
self.data_type,
)
def fast_mla_decode_plan(
self,
qo_indptr_cpu: torch.Tensor,
kv_indptr_cpu: torch.Tensor,
kv_indices: torch.Tensor,
kv_len_arr_cpu: torch.Tensor,
num_heads: int,
head_dim_ckv: int,
head_dim_kpe: int,
page_size: int,
causal: bool,
sm_scale: float,
q_data_type: torch.dtype,
kv_data_type: torch.dtype,
) -> None:
"""A faster version of BatchMLAPagedAttentionWrapper::plan,
for skipping the stream synchronization in original plan function during
cuda graph replaying.
"""
self._causal = causal
self._page_size = page_size
self._sm_scale = sm_scale
with self.device as device:
stream = torch.cuda.current_stream(device).cuda_stream
self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_cpu,
kv_indptr_cpu,
kv_len_arr_cpu,
num_heads,
head_dim_ckv,
causal,
stream,
)

View File

@@ -230,10 +230,9 @@ class TritonAttnBackend(AttentionBackend):
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
forward_mode: ForwardMode,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
**kwargs,
):
assert encoder_lens is None, "Not supported"
@@ -309,10 +308,9 @@ class TritonAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
forward_mode: ForwardMode,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
**kwargs,
):
# NOTE: encoder_lens expected to be zeros or None
if forward_mode.is_decode_or_idle():

View File

@@ -582,9 +582,6 @@ class ScheduleBatch:
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
# For decode
decode_seq_lens: List[int] = None
# For extend and mixed chunekd prefill
prefix_lens: List[int] = None
extend_lens: List[int] = None
@@ -1171,10 +1168,8 @@ class ScheduleBatch:
def get_model_worker_batch(self):
if self.forward_mode.is_decode_or_idle():
decode_seq_lens = self.seq_lens.cpu()
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
decode_seq_lens = None
extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
@@ -1199,7 +1194,6 @@ class ScheduleBatch:
top_logprobs_nums=self.top_logprobs_nums,
global_num_tokens=self.global_num_tokens,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
decode_seq_lens=decode_seq_lens,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
@@ -1273,9 +1267,6 @@ class ModelWorkerBatch:
global_num_tokens: Optional[List[int]]
can_run_dp_cuda_graph: bool
# For decode
decode_seq_lens: Optional[torch.Tensor]
# For extend
extend_num_tokens: Optional[int]
extend_seq_lens: Optional[List[int]]

View File

@@ -199,10 +199,6 @@ class CudaGraphRunner:
if self.enable_torch_compile:
set_torch_compile_config()
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
# Graph inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
@@ -373,9 +369,9 @@ class CudaGraphRunner:
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
forward_batch.forward_mode,
encoder_lens=encoder_lens,
spec_info=forward_batch.spec_info,
forward_batch.spec_info,
)
# Run and capture
@@ -438,7 +434,6 @@ class CudaGraphRunner:
if bs != raw_bs:
self.seq_lens.fill_(1)
self.out_cache_loc.zero_()
self.seq_lens_cpu.fill_(1)
# Common inputs
self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
@@ -446,8 +441,6 @@ class CudaGraphRunner:
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)
if forward_batch.decode_seq_lens_cpu is not None:
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
@@ -463,10 +456,9 @@ class CudaGraphRunner:
self.req_pool_indices,
self.seq_lens,
forward_batch.seq_lens_sum + (bs - raw_bs),
self.encoder_lens,
forward_batch.forward_mode,
encoder_lens=self.encoder_lens,
spec_info=forward_batch.spec_info,
seq_lens_cpu=self.seq_lens_cpu,
forward_batch.spec_info,
)
# Replay

View File

@@ -152,9 +152,6 @@ class ForwardBatch:
# Position information
positions: torch.Tensor = None
# For decode
decode_seq_lens_cpu: Optional[torch.Tensor] = None
# For extend
extend_num_tokens: Optional[int] = None
extend_seq_lens: Optional[torch.Tensor] = None
@@ -259,8 +256,6 @@ class ForwardBatch:
if ret.forward_mode.is_decode():
if ret.positions is None:
ret.positions = clamp_position(batch.seq_lens)
if ret.decode_seq_lens_cpu is None:
ret.decode_seq_lens_cpu = batch.decode_seq_lens
else:
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32