feat: support flashinfer mla attention for deepseek v3 (#3550)
This commit is contained in:
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
|
||||
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
@@ -20,6 +21,7 @@ import triton.language as tl
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.layers.attention import AttentionBackend
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
@@ -35,7 +37,7 @@ if is_flashinfer_available():
|
||||
BatchPrefillWithRaggedKVCacheWrapper,
|
||||
)
|
||||
from flashinfer.cascade import merge_state
|
||||
from flashinfer.decode import PosEncodingMode
|
||||
from flashinfer.mla import BatchMLAPagedAttentionWrapper
|
||||
|
||||
|
||||
class WrapperDispatch(Enum):
|
||||
@@ -45,7 +47,9 @@ class WrapperDispatch(Enum):
|
||||
|
||||
@dataclass
|
||||
class DecodeMetadata:
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
|
||||
decode_wrappers: List[
|
||||
Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -103,6 +107,12 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
|
||||
global_config.flashinfer_workspace_size = 512 * 1024 * 1024
|
||||
|
||||
self.enable_flashinfer_mla = False
|
||||
if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
|
||||
if global_server_args_dict["enable_flashinfer_mla"]:
|
||||
self.enable_flashinfer_mla = True
|
||||
global_config.enable_flashinfer_mla = True
|
||||
|
||||
# Allocate buffers
|
||||
global global_workspace_buffer
|
||||
if global_workspace_buffer is None:
|
||||
@@ -120,6 +130,13 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
)
|
||||
for _ in range(self.num_wrappers)
|
||||
]
|
||||
if self.enable_flashinfer_mla:
|
||||
self.qo_indptr = [
|
||||
torch.zeros(
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
for _ in range(self.num_wrappers)
|
||||
]
|
||||
else:
|
||||
assert self.num_wrappers == 1
|
||||
self.kv_indptr = [kv_indptr_buf]
|
||||
@@ -153,13 +170,18 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
self.prefill_wrappers_verify.append(
|
||||
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
||||
)
|
||||
self.decode_wrappers.append(
|
||||
BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.workspace_buffer,
|
||||
"NHD",
|
||||
use_tensor_cores=self.decode_use_tensor_cores,
|
||||
if self.enable_flashinfer_mla:
|
||||
self.decode_wrappers.append(
|
||||
BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
|
||||
)
|
||||
else:
|
||||
self.decode_wrappers.append(
|
||||
BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.workspace_buffer,
|
||||
"NHD",
|
||||
use_tensor_cores=self.decode_use_tensor_cores,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Create indices updater
|
||||
if not skip_prefill:
|
||||
@@ -274,19 +296,32 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
decode_wrappers = []
|
||||
for i in range(self.num_wrappers):
|
||||
decode_wrappers.append(
|
||||
BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.workspace_buffer,
|
||||
"NHD",
|
||||
use_cuda_graph=True,
|
||||
use_tensor_cores=self.decode_use_tensor_cores,
|
||||
paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
|
||||
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
|
||||
paged_kv_last_page_len_buffer=self.kv_last_page_len[
|
||||
:num_tokens
|
||||
],
|
||||
if self.enable_flashinfer_mla:
|
||||
decode_wrappers.append(
|
||||
BatchMLAPagedAttentionWrapper(
|
||||
self.workspace_buffer,
|
||||
use_cuda_graph=True,
|
||||
qo_indptr=self.qo_indptr[i][: num_tokens + 1],
|
||||
kv_indptr=self.kv_indptr[i][: num_tokens + 1],
|
||||
kv_indices=self.cuda_graph_kv_indices[i],
|
||||
kv_len_arr=self.kv_last_page_len[:num_tokens],
|
||||
backend="fa2",
|
||||
)
|
||||
)
|
||||
else:
|
||||
decode_wrappers.append(
|
||||
BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.workspace_buffer,
|
||||
"NHD",
|
||||
use_cuda_graph=True,
|
||||
use_tensor_cores=self.decode_use_tensor_cores,
|
||||
paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
|
||||
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
|
||||
paged_kv_last_page_len_buffer=self.kv_last_page_len[
|
||||
:num_tokens
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
seq_lens_sum = seq_lens.sum().item()
|
||||
self.indices_updater_decode.update(
|
||||
req_pool_indices,
|
||||
@@ -375,64 +410,94 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
):
|
||||
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
|
||||
self._get_wrapper_idx(layer)
|
||||
]
|
||||
cache_loc = (
|
||||
forward_batch.out_cache_loc
|
||||
if not layer.is_cross_attention
|
||||
else forward_batch.encoder_out_cache_loc
|
||||
)
|
||||
|
||||
logits_soft_cap = layer.logit_cap
|
||||
|
||||
if not self.forward_metadata.use_ragged:
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
|
||||
o = prefill_wrapper_paged.forward(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
causal=not layer.is_cross_attention,
|
||||
sm_scale=layer.scaling,
|
||||
window_left=layer.sliding_window_size,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
k_scale=layer.k_scale,
|
||||
v_scale=layer.v_scale,
|
||||
if global_config.enable_flashinfer_mla:
|
||||
cache_loc = (
|
||||
forward_batch.out_cache_loc
|
||||
if not layer.is_cross_attention
|
||||
else forward_batch.encoder_out_cache_loc
|
||||
)
|
||||
else:
|
||||
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
|
||||
|
||||
logits_soft_cap = layer.logit_cap
|
||||
|
||||
o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
|
||||
q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
k.view(-1, layer.tp_k_head_num, layer.head_dim),
|
||||
v.view(-1, layer.tp_v_head_num, layer.head_dim),
|
||||
v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
|
||||
causal=True,
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
)
|
||||
|
||||
if self.forward_metadata.extend_no_prefix:
|
||||
o = o1
|
||||
else:
|
||||
o2, s2 = prefill_wrapper_paged.forward_return_lse(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
causal=False,
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=layer.logit_cap,
|
||||
)
|
||||
|
||||
o, _ = merge_state(o1, s1, o2, s2)
|
||||
o = o1
|
||||
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
layer,
|
||||
cache_loc,
|
||||
k,
|
||||
v,
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
else:
|
||||
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
|
||||
self._get_wrapper_idx(layer)
|
||||
]
|
||||
cache_loc = (
|
||||
forward_batch.out_cache_loc
|
||||
if not layer.is_cross_attention
|
||||
else forward_batch.encoder_out_cache_loc
|
||||
)
|
||||
|
||||
logits_soft_cap = layer.logit_cap
|
||||
|
||||
if not self.forward_metadata.use_ragged:
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
|
||||
o = prefill_wrapper_paged.forward(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
causal=not layer.is_cross_attention,
|
||||
sm_scale=layer.scaling,
|
||||
window_left=layer.sliding_window_size,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
k_scale=layer.k_scale,
|
||||
v_scale=layer.v_scale,
|
||||
)
|
||||
else:
|
||||
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
|
||||
q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
k.view(-1, layer.tp_k_head_num, layer.head_dim),
|
||||
v.view(-1, layer.tp_v_head_num, layer.head_dim),
|
||||
causal=True,
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
)
|
||||
|
||||
if self.forward_metadata.extend_no_prefix:
|
||||
o = o1
|
||||
else:
|
||||
o2, s2 = prefill_wrapper_paged.forward_return_lse(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
causal=False,
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=layer.logit_cap,
|
||||
)
|
||||
|
||||
o, _ = merge_state(o1, s1, o2, s2)
|
||||
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
@@ -452,23 +517,45 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
else forward_batch.encoder_out_cache_loc
|
||||
)
|
||||
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
if self.enable_flashinfer_mla:
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer,
|
||||
cache_loc,
|
||||
k,
|
||||
v,
|
||||
)
|
||||
reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
|
||||
o = decode_wrapper.run(
|
||||
reshaped_q[:, :, : layer.v_head_dim],
|
||||
reshaped_q[:, :, layer.v_head_dim :],
|
||||
reshaped_k[:, :, : layer.v_head_dim],
|
||||
reshaped_k[:, :, layer.v_head_dim :],
|
||||
)
|
||||
|
||||
o = decode_wrapper.forward(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=layer.logit_cap,
|
||||
k_scale=layer.k_scale,
|
||||
v_scale=layer.v_scale,
|
||||
)
|
||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
else:
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
o = decode_wrapper.forward(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=layer.logit_cap,
|
||||
k_scale=layer.k_scale,
|
||||
v_scale=layer.v_scale,
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
|
||||
def _get_wrapper_idx(self, layer: RadixAttention):
|
||||
if self.num_wrappers == 1:
|
||||
@@ -516,7 +603,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
decode_wrappers: List[
|
||||
Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
|
||||
],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
):
|
||||
@@ -528,7 +617,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
decode_wrappers: List[
|
||||
Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
|
||||
],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
):
|
||||
@@ -609,7 +700,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
|
||||
def call_begin_forward(
|
||||
self,
|
||||
wrapper: BatchDecodeWithPagedKVCacheWrapper,
|
||||
wrapper: Union[
|
||||
BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
|
||||
],
|
||||
req_pool_indices: torch.Tensor,
|
||||
paged_kernel_lens: torch.Tensor,
|
||||
paged_kernel_lens_sum: int,
|
||||
@@ -637,18 +730,37 @@ class FlashInferIndicesUpdaterDecode:
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
bs = kv_indptr.shape[0] - 1
|
||||
|
||||
wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
self.kv_last_page_len[:bs],
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
1,
|
||||
data_type=self.data_type,
|
||||
q_data_type=self.q_data_type,
|
||||
non_blocking=True,
|
||||
)
|
||||
if global_config.enable_flashinfer_mla:
|
||||
sm_scale = 1.0 / math.sqrt(192)
|
||||
q_indptr = torch.arange(0, bs + 1).to(0).int()
|
||||
kv_lens = paged_kernel_lens.to(torch.int32)
|
||||
wrapper.plan(
|
||||
q_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_lens,
|
||||
self.num_qo_heads,
|
||||
512,
|
||||
64,
|
||||
1,
|
||||
False,
|
||||
sm_scale,
|
||||
self.data_type,
|
||||
self.data_type,
|
||||
)
|
||||
else:
|
||||
wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
self.kv_last_page_len[:bs],
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
1,
|
||||
data_type=self.data_type,
|
||||
q_data_type=self.q_data_type,
|
||||
non_blocking=True,
|
||||
)
|
||||
|
||||
|
||||
class FlashInferIndicesUpdaterPrefill:
|
||||
@@ -857,30 +969,42 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
|
||||
# extend part
|
||||
if use_ragged:
|
||||
wrapper_ragged.begin_forward(
|
||||
qo_indptr,
|
||||
if global_config.enable_flashinfer_mla:
|
||||
wrapper_ragged.begin_forward(
|
||||
qo_indptr=qo_indptr,
|
||||
kv_indptr=qo_indptr,
|
||||
num_qo_heads=self.num_qo_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_dim_qk=192,
|
||||
head_dim_vo=128,
|
||||
q_data_type=self.q_data_type,
|
||||
)
|
||||
else:
|
||||
wrapper_ragged.begin_forward(
|
||||
qo_indptr,
|
||||
qo_indptr,
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
q_data_type=self.q_data_type,
|
||||
)
|
||||
|
||||
if not global_config.enable_flashinfer_mla:
|
||||
# cached part
|
||||
wrapper_paged.begin_forward(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
self.kv_last_page_len[:bs],
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
1,
|
||||
q_data_type=self.q_data_type,
|
||||
custom_mask=custom_mask,
|
||||
non_blocking=True,
|
||||
)
|
||||
|
||||
# cached part
|
||||
wrapper_paged.begin_forward(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
self.kv_last_page_len[:bs],
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
1,
|
||||
q_data_type=self.q_data_type,
|
||||
custom_mask=custom_mask,
|
||||
non_blocking=True,
|
||||
)
|
||||
|
||||
|
||||
class FlashInferMultiStepDraftBackend:
|
||||
"""
|
||||
@@ -1163,6 +1287,7 @@ def fast_decode_plan(
|
||||
window_left,
|
||||
logits_soft_cap,
|
||||
head_dim,
|
||||
head_dim,
|
||||
empty_q_data,
|
||||
empty_kv_cache,
|
||||
stream.cuda_stream,
|
||||
|
||||
Reference in New Issue
Block a user