[Revision] Add fast decode plan for flashinfer mla (#4012)
This commit is contained in:
@@ -45,6 +45,7 @@ class AttentionBackend(ABC):
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
"""Init the metadata for a forward pass for replying a cuda graph."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -353,6 +353,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
self.indices_updater_decode.update(
|
||||
@@ -1058,6 +1059,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
encoder_lens=None,
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
spec_info=forward_batch.spec_info,
|
||||
seq_lens_cpu=None,
|
||||
)
|
||||
|
||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||
|
||||
@@ -10,6 +10,7 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -62,6 +63,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
|
||||
# Parse constants
|
||||
self.max_context_len = model_runner.model_config.context_len
|
||||
self.device = model_runner.device
|
||||
|
||||
global_config.enable_flashinfer_mla = True
|
||||
|
||||
@@ -84,10 +86,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
|
||||
self.kv_last_page_len = torch.ones(
|
||||
(max_bs,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
|
||||
self.q_indptr_decode = torch.arange(
|
||||
0, max_bs + 1, dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
@@ -125,6 +123,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
decode_wrapper=self.decode_wrapper,
|
||||
init_metadata_replay=False,
|
||||
)
|
||||
self.forward_metadata = DecodeMetadata(self.decode_wrapper)
|
||||
else:
|
||||
@@ -160,13 +159,20 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
cuda_graph_kv_indices = kv_indices_buf
|
||||
|
||||
self.cuda_graph_kv_indices = cuda_graph_kv_indices
|
||||
self.cuda_graph_custom_mask = torch.zeros(
|
||||
(max_bs * self.max_context_len),
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
|
||||
self.cuda_graph_kv_indptr = self.kv_indptr.clone()
|
||||
self.cuda_graph_kv_lens = torch.ones(
|
||||
(max_bs,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.cuda_graph_qk_indptr = self.kv_indptr.clone()
|
||||
self.cuda_graph_qo_indptr = self.kv_indptr.clone()
|
||||
|
||||
# For fast decode plan in graph replaying
|
||||
self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
|
||||
self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
|
||||
self.fast_decode_kwargs = {
|
||||
"qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu,
|
||||
"kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu,
|
||||
"kv_indices": self.cuda_graph_kv_indices,
|
||||
}
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
@@ -182,10 +188,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
decode_wrapper = BatchMLAPagedAttentionWrapper(
|
||||
self.workspace_buffer,
|
||||
use_cuda_graph=True,
|
||||
qo_indptr=self.qo_indptr[: num_tokens + 1],
|
||||
kv_indptr=self.kv_indptr[: num_tokens + 1],
|
||||
qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],
|
||||
kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],
|
||||
kv_indices=self.cuda_graph_kv_indices,
|
||||
kv_len_arr=self.kv_last_page_len[:num_tokens],
|
||||
kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
|
||||
backend="auto",
|
||||
)
|
||||
|
||||
@@ -195,9 +201,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
decode_wrapper=decode_wrapper,
|
||||
init_metadata_replay=False,
|
||||
)
|
||||
self.decode_cuda_graph_metadata[bs] = decode_wrapper
|
||||
self.forward_metadata = DecodeMetadata(decode_wrapper)
|
||||
decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {forward_mode=}")
|
||||
|
||||
@@ -210,13 +218,28 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
kv_len_arr_cpu = seq_lens_cpu[:bs]
|
||||
self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
|
||||
kv_len_arr_cpu, dim=0
|
||||
)
|
||||
self.fast_decode_kwargs.update(
|
||||
{
|
||||
"qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1],
|
||||
"kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1],
|
||||
"kv_len_arr_cpu": kv_len_arr_cpu,
|
||||
}
|
||||
)
|
||||
|
||||
self.indices_updater_decode.update(
|
||||
req_pool_indices[:bs],
|
||||
seq_lens[:bs],
|
||||
seq_lens_sum,
|
||||
decode_wrapper=self.decode_cuda_graph_metadata[bs],
|
||||
init_metadata_replay=True,
|
||||
**self.fast_decode_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
||||
@@ -316,7 +339,6 @@ class FlashInferMLAIndicesUpdaterDecode:
|
||||
|
||||
# Buffers and wrappers
|
||||
self.kv_indptr = attn_backend.kv_indptr
|
||||
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.q_indptr = attn_backend.q_indptr_decode
|
||||
|
||||
@@ -326,6 +348,8 @@ class FlashInferMLAIndicesUpdaterDecode:
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
decode_wrapper: BatchMLAPagedAttentionWrapper,
|
||||
init_metadata_replay: bool = False,
|
||||
**fast_decode_kwargs,
|
||||
):
|
||||
decode_wrapper = decode_wrapper or self.decode_wrapper
|
||||
self.call_begin_forward(
|
||||
@@ -335,6 +359,8 @@ class FlashInferMLAIndicesUpdaterDecode:
|
||||
seq_lens_sum,
|
||||
self.q_indptr,
|
||||
self.kv_indptr,
|
||||
init_metadata_replay,
|
||||
**fast_decode_kwargs,
|
||||
)
|
||||
|
||||
def call_begin_forward(
|
||||
@@ -345,14 +371,19 @@ class FlashInferMLAIndicesUpdaterDecode:
|
||||
paged_kernel_lens_sum: int,
|
||||
q_indptr: torch.Tensor,
|
||||
kv_indptr: torch.Tensor,
|
||||
init_metadata_replay: bool = False,
|
||||
**fast_decode_kwargs,
|
||||
):
|
||||
bs = len(req_pool_indices)
|
||||
q_indptr = q_indptr[: bs + 1]
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
||||
kv_indices = (
|
||||
torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
|
||||
if not init_metadata_replay
|
||||
else fast_decode_kwargs["kv_indices"]
|
||||
)
|
||||
|
||||
kv_lens = paged_kernel_lens.to(torch.int32)
|
||||
sm_scale = self.scaling
|
||||
|
||||
@@ -365,21 +396,36 @@ class FlashInferMLAIndicesUpdaterDecode:
|
||||
kv_indices,
|
||||
self.req_to_token.shape[1],
|
||||
)
|
||||
|
||||
wrapper.plan(
|
||||
q_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_lens,
|
||||
self.num_local_heads,
|
||||
self.kv_lora_rank,
|
||||
self.qk_rope_head_dim,
|
||||
1,
|
||||
False,
|
||||
sm_scale,
|
||||
self.data_type,
|
||||
self.data_type,
|
||||
)
|
||||
if not init_metadata_replay:
|
||||
wrapper.plan(
|
||||
q_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_lens,
|
||||
self.num_local_heads,
|
||||
self.kv_lora_rank,
|
||||
self.qk_rope_head_dim,
|
||||
1,
|
||||
False,
|
||||
sm_scale,
|
||||
self.data_type,
|
||||
self.data_type,
|
||||
)
|
||||
else:
|
||||
wrapper.plan(
|
||||
fast_decode_kwargs["qo_indptr_cpu"],
|
||||
fast_decode_kwargs["kv_indptr_cpu"],
|
||||
kv_indices,
|
||||
fast_decode_kwargs["kv_len_arr_cpu"],
|
||||
self.num_local_heads,
|
||||
self.kv_lora_rank,
|
||||
self.qk_rope_head_dim,
|
||||
1,
|
||||
False,
|
||||
sm_scale,
|
||||
self.data_type,
|
||||
self.data_type,
|
||||
)
|
||||
|
||||
|
||||
class FlashInferMLAIndicesUpdaterPrefill:
|
||||
@@ -399,7 +445,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
||||
|
||||
# Buffers and wrappers
|
||||
self.kv_indptr = attn_backend.kv_indptr
|
||||
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||
self.qo_indptr = attn_backend.qo_indptr
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
|
||||
@@ -496,3 +541,42 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
||||
self.q_data_type,
|
||||
self.data_type,
|
||||
)
|
||||
|
||||
|
||||
def fast_mla_decode_plan(
|
||||
self,
|
||||
qo_indptr_cpu: torch.Tensor,
|
||||
kv_indptr_cpu: torch.Tensor,
|
||||
kv_indices: torch.Tensor,
|
||||
kv_len_arr_cpu: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_dim_ckv: int,
|
||||
head_dim_kpe: int,
|
||||
page_size: int,
|
||||
causal: bool,
|
||||
sm_scale: float,
|
||||
q_data_type: torch.dtype,
|
||||
kv_data_type: torch.dtype,
|
||||
) -> None:
|
||||
"""A faster version of BatchMLAPagedAttentionWrapper::plan,
|
||||
for skipping the stream synchronization in original plan function during
|
||||
cuda graph replaying.
|
||||
"""
|
||||
self._causal = causal
|
||||
self._page_size = page_size
|
||||
self._sm_scale = sm_scale
|
||||
|
||||
with self.device as device:
|
||||
stream = torch.cuda.current_stream(device).cuda_stream
|
||||
self._cached_module.plan(
|
||||
self._float_workspace_buffer,
|
||||
self._int_workspace_buffer,
|
||||
self._pin_memory_int_workspace_buffer,
|
||||
qo_indptr_cpu,
|
||||
kv_indptr_cpu,
|
||||
kv_len_arr_cpu,
|
||||
num_heads,
|
||||
head_dim_ckv,
|
||||
causal,
|
||||
stream,
|
||||
)
|
||||
|
||||
@@ -312,6 +312,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
# NOTE: encoder_lens expected to be zeros or None
|
||||
if forward_mode.is_decode_or_idle():
|
||||
@@ -587,6 +588,7 @@ class TritonMultiStepDraftBackend:
|
||||
encoder_lens=None,
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
spec_info=forward_batch.spec_info,
|
||||
seq_lens_cpu=None,
|
||||
)
|
||||
|
||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||
|
||||
@@ -1187,8 +1187,13 @@ class ScheduleBatch:
|
||||
|
||||
def get_model_worker_batch(self) -> ModelWorkerBatch:
|
||||
if self.forward_mode.is_decode_or_idle():
|
||||
if global_server_args_dict["enable_flashinfer_mla"]:
|
||||
decode_seq_lens = self.seq_lens.cpu()
|
||||
else:
|
||||
decode_seq_lens = None
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||
else:
|
||||
decode_seq_lens = None
|
||||
extend_seq_lens = self.extend_lens
|
||||
extend_prefix_lens = self.prefix_lens
|
||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||
@@ -1215,6 +1220,7 @@ class ScheduleBatch:
|
||||
global_num_tokens=self.global_num_tokens,
|
||||
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||
decode_seq_lens=decode_seq_lens,
|
||||
extend_num_tokens=self.extend_num_tokens,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_prefix_lens=extend_prefix_lens,
|
||||
@@ -1291,6 +1297,9 @@ class ModelWorkerBatch:
|
||||
global_num_tokens_for_logprob: Optional[List[int]]
|
||||
can_run_dp_cuda_graph: bool
|
||||
|
||||
# For decode
|
||||
decode_seq_lens: Optional[torch.Tensor]
|
||||
|
||||
# For extend
|
||||
extend_num_tokens: Optional[int]
|
||||
extend_seq_lens: Optional[List[int]]
|
||||
|
||||
@@ -200,6 +200,9 @@ class CudaGraphRunner:
|
||||
)
|
||||
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
||||
self.encoder_len_fill_value = 0
|
||||
self.seq_lens_cpu = torch.full(
|
||||
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||
)
|
||||
|
||||
if self.enable_torch_compile:
|
||||
set_torch_compile_config()
|
||||
@@ -448,6 +451,10 @@ class CudaGraphRunner:
|
||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
||||
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
||||
if forward_batch.decode_seq_lens_cpu is not None:
|
||||
if bs != raw_bs:
|
||||
self.seq_lens_cpu.fill_(1)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
@@ -466,6 +473,7 @@ class CudaGraphRunner:
|
||||
self.encoder_lens,
|
||||
forward_batch.forward_mode,
|
||||
forward_batch.spec_info,
|
||||
seq_lens_cpu=self.seq_lens_cpu,
|
||||
)
|
||||
|
||||
# Replay
|
||||
|
||||
@@ -156,6 +156,9 @@ class ForwardBatch:
|
||||
# Position information
|
||||
positions: torch.Tensor = None
|
||||
|
||||
# For decode
|
||||
decode_seq_lens_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
# For extend
|
||||
extend_num_tokens: Optional[int] = None
|
||||
extend_seq_lens: Optional[torch.Tensor] = None
|
||||
@@ -280,6 +283,8 @@ class ForwardBatch:
|
||||
if ret.forward_mode.is_decode():
|
||||
if ret.positions is None:
|
||||
ret.positions = clamp_position(batch.seq_lens)
|
||||
if ret.decode_seq_lens_cpu is None:
|
||||
ret.decode_seq_lens_cpu = batch.decode_seq_lens
|
||||
else:
|
||||
ret.extend_seq_lens = torch.tensor(
|
||||
batch.extend_seq_lens, dtype=torch.int32
|
||||
|
||||
Reference in New Issue
Block a user