Support FlashMLA backend (#4472)
Co-authored-by: yinfan98 <1106310035@qq.com>
python/sglang/srt/layers/attention/flashmla_backend.py (new file, 128 lines)
@@ -0,0 +1,128 @@
from __future__ import annotations

"""
Attention backend for FlashMLA.

The initial FlashMLA integration matches the accuracy of the existing MLA
backends, but its performance still lags slightly.

TODO:
- Support FlashMLA decode with CUDA graph
- Enable speculative sampling in FlashMLA
- Integrate FA3 prefill
"""

from typing import TYPE_CHECKING, Optional, Union

import torch
import triton
from flash_mla import flash_mla_with_kvcache, get_mla_metadata

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


# FlashMLA only supports a page size of 64.
PAGE_SIZE = 64

class FlashMLABackend(FlashInferMLAAttnBackend):
    """FlashMLA attention kernels."""

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__(
            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
        )

        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
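
    # Note: for DeepSeek-V2-style MLA models, kv_lora_rank=512 and
    # qk_rope_head_dim=64, so each cached token stores a single compressed
    # kv_cache_dim=576 vector rather than per-head K and V (illustrative
    # figures, not asserted by this commit).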

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                # Write this step's new KV entries into the paged KV cache.
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )
        bs = forward_batch.batch_size

        # Build the FlashMLA block table: one row of page indices per request.
        max_seqlen_pad = triton.cdiv(forward_batch.seq_lens.max().item(), PAGE_SIZE)
        flashmla_index = torch.full(
            (bs, max_seqlen_pad), -1, dtype=torch.int32, device=q.device
        )
        create_flashmla_kv_indices_triton[(bs,)](
            self.indices_updater_decode.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            None,
            flashmla_index,
            self.indices_updater_decode.req_to_token.size(1),
            flashmla_index.size(1),
            max_seqlen_pad,
        )

        mla_metadata, mla_splits = get_mla_metadata(
            forward_batch.seq_lens.to(torch.int32),
            1 * self.num_q_heads // self.num_kv_heads,
            self.num_kv_heads,
        )

        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)

        o, _ = flash_mla_with_kvcache(
            q=reshape_q,
            k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
            block_table=flashmla_index,
            cache_seqlens=forward_batch.seq_lens.to(torch.int32),
            head_dim_v=self.kv_lora_rank,  # TODO: retrieve from config.
            tile_scheduler_metadata=mla_metadata,
            num_splits=mla_splits,
            softmax_scale=layer.scaling,
            causal=False,
        )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
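
The Triton helper used above fills the FlashMLA block table on-device. As a
plain-PyTorch reference, here is a minimal sketch of the same mapping — for
illustration only, not part of this commit; it assumes req_to_token rows are
allocated in whole 64-token pages:

import torch

PAGE_SIZE = 64  # FlashMLA page size

def build_block_table(req_to_token: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    """Return [bs, max_pages] page indices; entry [i, j] is page j of request i."""
    bs = seq_lens.numel()
    max_pages = -(-int(seq_lens.max()) // PAGE_SIZE)  # triton.cdiv equivalent
    block_table = torch.full((bs, max_pages), -1, dtype=torch.int32)
    for i in range(bs):
        num_pages = -(-int(seq_lens[i]) // PAGE_SIZE)
        # The first token slot of each page determines the page index:
        # token_index // PAGE_SIZE, exactly what the Triton kernel stores.
        first_tokens = req_to_token[i, : num_pages * PAGE_SIZE : PAGE_SIZE]
        block_table[i, :num_pages] = (first_tokens // PAGE_SIZE).to(torch.int32)
    return block_table

# Example: tokens 0..255 backing each request, pages allocated contiguously.
req_to_token = torch.arange(4 * PAGE_SIZE).repeat(2, 1)
print(build_block_table(req_to_token, torch.tensor([65, 128])))  # [[0, 1], [0, 1]]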

python/sglang/srt/layers/attention/utils.py
@@ -15,6 +15,7 @@ def create_flashinfer_kv_indices_triton(
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)

    # Find the req pool index; this maps the batch entry to its token row.
    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_indices_offset = tl.load(kv_indptr + pid)

@@ -37,3 +38,56 @@ def create_flashinfer_kv_indices_triton(
            mask=mask,
        )
        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)


@triton.jit
def create_flashmla_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
    kv_indices_ptr_stride: tl.constexpr,
    max_pagesize: tl.constexpr,
):
    PAGED_SIZE: tl.constexpr = 64
    BLOCK_SIZE: tl.constexpr = 4096
    NUM_PAGE_PER_BLOCK: tl.constexpr = 64
    pid = tl.program_id(axis=0)

    # Find the req pool index; this maps the batch entry to its token row.
    req_pool_index = tl.load(req_pool_indices_ptr + pid)

    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start

    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)
    num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)

    for i in range(num_pages_loop):
        # Token offset of the first token in each page handled this iteration.
        paged_offset = (
            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
        ) * PAGED_SIZE
        # Output slot (block-table column) for each of those pages.
        paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK

        mask = paged_offset <= num_paged * PAGED_SIZE
        mask_out = paged_offset_out <= num_paged

        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + kv_start
            + paged_offset,
            mask=mask,
        )
        # Token slots are page-aligned, so token index // page size is the page index.
        tl.store(
            kv_indices_ptr + pid * kv_indices_ptr_stride + paged_offset_out,
            data // PAGED_SIZE,
            mask=mask_out,
        )
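
The constants above are linked: each loop iteration covers NUM_PAGE_PER_BLOCK
pages of PAGED_SIZE tokens, i.e. BLOCK_SIZE = 64 * 64 = 4096 tokens. A small
self-check of that arithmetic (illustrative only, not part of this commit):

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

assert 64 * 64 == 4096          # pages/iteration * tokens/page == tokens/iteration
assert cdiv(10_000, 64) == 157  # pages needed for a 10k-token sequence
assert cdiv(10_000, 4096) == 3  # kernel loop iterations for that sequence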

python/sglang/srt/managers/schedule_batch.py
@@ -71,6 +71,7 @@ global_server_args_dict = {
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
    "enable_flashmla": ServerArgs.enable_flashmla,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
}

@@ -1273,7 +1274,10 @@ class ScheduleBatch:

     def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
-            if global_server_args_dict["enable_flashinfer_mla"]:
+            if (
+                global_server_args_dict["enable_flashinfer_mla"]
+                or global_server_args_dict["enable_flashmla"]
+            ):
                 decode_seq_lens = self.seq_lens.cpu()
             else:
                 decode_seq_lens = None

python/sglang/srt/model_executor/model_runner.py
@@ -149,6 +149,7 @@ class ModelRunner:
                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                "enable_flashmla": server_args.enable_flashmla,
                "disable_radix_cache": server_args.disable_radix_cache,
                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,

@@ -223,6 +224,9 @@ class ModelRunner:
                    "MLA optimization is turned on. Use flashinfer mla backend."
                )
                server_args.attention_backend = "flashinfer_mla"
            elif server_args.enable_flashmla:
                logger.info("MLA optimization is turned on. Use flashmla decode.")
                server_args.attention_backend = "flashmla"
            else:
                logger.info("MLA optimization is turned on. Use triton backend.")
                server_args.attention_backend = "triton"

@@ -840,6 +844,10 @@ class ModelRunner:
                )

            self.attn_backend = FlashInferMLAAttnBackend(self)
        elif self.server_args.attention_backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            self.attn_backend = FlashMLABackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"

python/sglang/srt/server_args.py
@@ -173,6 +173,7 @@ class ServerArgs:
    tool_call_parser: str = None
    enable_hierarchical_cache: bool = False
    enable_flashinfer_mla: bool = False
    enable_flashmla: bool = False
    flashinfer_mla_disable_ragged: bool = False
    warmups: Optional[str] = None

@@ -227,6 +228,8 @@ class ServerArgs:

        assert self.chunked_prefill_size % self.page_size == 0

        if self.enable_flashmla:
            assert self.page_size == 64, "FlashMLA only supports page_size=64"
        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics: when serving TP1/TP2 models on
            # lower-end GPUs with HBM < 25 GB, cuda graph can be disabled, or
            # `cuda_graph_max_bs` set to a very small value, to reduce the
            # memory overhead of creating cuda graphs with almost no impact on
            # performance. However, serving with TP4 or TP8 needs cuda graph
            # enabled to maintain high performance; there, `cuda_graph_max_bs`
            # can be set to 80 (half of the default 160) to cut the memory
            # overhead. Logs from TP4 serving of qwen2-72b show that 80 is
            # sufficient and avoids the OOM issues seen with the original 160
            # on lower-end GPUs.

@@ -753,6 +756,11 @@ class ServerArgs:
            action="store_true",
            help="Enable FlashInfer MLA optimization",
        )
        parser.add_argument(
            "--enable-flashmla",
            action="store_true",
            help="Enable FlashMLA decode optimization",
        )
        parser.add_argument(
            "--flashinfer-mla-disable-ragged",
            action="store_true",
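
With these arguments in place, FlashMLA decode can be enabled at launch. A
hedged example (the model path is illustrative, and --page-size is assumed to
be the existing CLI flag backing ServerArgs.page_size, which must be 64 here):

python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-Lite --enable-flashmla --page-size 64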