[Feature]Support ragged prefill in flashinfer mla backend (#3967)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: pankajroark <pankajroark@users.noreply.github.com>
@@ -133,7 +133,6 @@ Please consult the documentation below to learn more about the parameters you ma
 
 * `attention_backend`: The backend for attention computation and KV cache management.
 * `sampling_backend`: The backend for sampling.
-* `enable_flashinfer_mla`: The backend for flashinfer MLA wrapper that accelerates deepseek models. (In Experiment Stage)
 
 ## Constrained Decoding
 
@@ -186,3 +185,5 @@ Please consult the documentation below to learn more about the parameters you ma
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
+* `enable_flashinfer_mla`: The backend for flashinfer MLA wrapper that accelerates deepseek models.
+* `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on.
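For context, both options are plain CLI switches on the server launcher. A minimal launch sketch, reusing the test helpers that the new test file at the end of this diff imports (the model is simply the CI DeepSeek checkpoint used there):

```python
# Sketch only; mirrors the setup in test/srt/test_mla_flashinfer.py below.
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

process = popen_launch_server(
    "sgl-project/sglang-ci-dsv3-test",      # DeepSeek-style model with MLA
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=[
        "--trust-remote-code",
        "--enable-flashinfer-mla",           # turn on the FlashInfer MLA backend
        "--flashinfer-mla-disable-ragged",   # optional: skip the ragged prefill wrapper
    ],
)
```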
@@ -37,7 +37,6 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.mla import BatchMLAPagedAttentionWrapper


 class WrapperDispatch(Enum):
@@ -47,16 +46,12 @@ class WrapperDispatch(Enum):

 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[
-        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-    ]
+    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]


 @dataclass
 class PrefillMetadata:
-    prefill_wrappers: List[
-        Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-    ]
+    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool

@@ -109,12 +104,6 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

-        self.enable_flashinfer_mla = False
-        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
-            if global_server_args_dict["enable_flashinfer_mla"]:
-                self.enable_flashinfer_mla = True
-                global_config.enable_flashinfer_mla = True
-
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -132,13 +121,6 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
-            if self.enable_flashinfer_mla:
-                self.qo_indptr = [
-                    torch.zeros(
-                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-                    )
-                    for _ in range(self.num_wrappers)
-                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]
@@ -162,24 +144,6 @@ class FlashInferAttnBackend(AttentionBackend):
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             if not skip_prefill:
-                if (
-                    self.enable_flashinfer_mla
-                    and not global_server_args_dict["disable_radix_cache"]
-                ):
-                    # use mla paged prefill
-                    self.prefill_wrappers_paged.append(
-                        BatchMLAPagedAttentionWrapper(
-                            self.workspace_buffer,
-                            backend="fa2",
-                        )
-                    )
-                    self.prefill_wrappers_verify.append(
-                        BatchMLAPagedAttentionWrapper(
-                            self.workspace_buffer,
-                            backend="fa2",
-                        )
-                    )
-                else:
                 self.prefill_wrappers_paged.append(
                     BatchPrefillWithPagedKVCacheWrapper(
                         self.workspace_buffer,
@@ -188,15 +152,9 @@ class FlashInferAttnBackend(AttentionBackend):
                     )
                 )
                 self.prefill_wrappers_verify.append(
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.workspace_buffer, "NHD"
-                    )
+                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
                 )
-            if self.enable_flashinfer_mla:
-                self.decode_wrappers.append(
-                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
-                )
-            else:
             self.decode_wrappers.append(
                 BatchDecodeWithPagedKVCacheWrapper(
                     self.workspace_buffer,
@@ -259,10 +217,7 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens

-            if self.is_multimodal or (
-                self.enable_flashinfer_mla
-                and not global_server_args_dict["disable_radix_cache"]
-            ):
+            if self.is_multimodal:
                 use_ragged = False
                 extend_no_prefix = False
             else:
@@ -321,19 +276,6 @@ class FlashInferAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-                if self.enable_flashinfer_mla:
-                    decode_wrappers.append(
-                        BatchMLAPagedAttentionWrapper(
-                            self.workspace_buffer,
-                            use_cuda_graph=True,
-                            qo_indptr=self.qo_indptr[i][: num_tokens + 1],
-                            kv_indptr=self.kv_indptr[i][: num_tokens + 1],
-                            kv_indices=self.cuda_graph_kv_indices[i],
-                            kv_len_arr=self.kv_last_page_len[:num_tokens],
-                            backend="fa2",
-                        )
-                    )
-                else:
                 decode_wrappers.append(
                     BatchDecodeWithPagedKVCacheWrapper(
                         self.workspace_buffer,
@@ -347,6 +289,7 @@ class FlashInferAttnBackend(AttentionBackend):
                     ],
                 )
             )
+
         seq_lens_sum = seq_lens.sum().item()
         self.indices_updater_decode.update(
             req_pool_indices,
@@ -435,56 +378,6 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        if global_config.enable_flashinfer_mla:
-            cache_loc = (
-                forward_batch.out_cache_loc
-                if not layer.is_cross_attention
-                else forward_batch.encoder_out_cache_loc
-            )
-
-            logits_soft_cap = layer.logit_cap
-
-            if global_server_args_dict["disable_radix_cache"]:
-                # use mla ragged prefill
-                o, _ = self.prefill_wrapper_ragged.forward_return_lse(
-                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
-                    causal=True,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=logits_soft_cap,
-                )
-
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-            else:
-                # use mla paged prefill
-                prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-                    self._get_wrapper_idx(layer)
-                ]
-                if k is not None:
-                    assert v is not None
-                    if save_kv_cache:
-                        forward_batch.token_to_kv_pool.set_kv_buffer(
-                            layer, cache_loc, k, v
-                        )
-                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-                k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
-                o = prefill_wrapper_paged.run(
-                    qall[:, :, : layer.v_head_dim],
-                    qall[:, :, layer.v_head_dim :],
-                    k_buf[:, :, : layer.v_head_dim],
-                    k_buf[:, :, layer.v_head_dim :],
-                )
-
-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-        else:
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -562,28 +455,6 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )

-        if self.enable_flashinfer_mla:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
-            o = decode_wrapper.run(
-                reshaped_q[:, :, : layer.v_head_dim],
-                reshaped_q[:, :, layer.v_head_dim :],
-                reshaped_k[:, :, : layer.v_head_dim],
-                reshaped_k[:, :, layer.v_head_dim :],
-            )
-
-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-        else:
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -648,9 +519,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
-            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -662,9 +531,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
-            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -745,9 +612,7 @@ class FlashInferIndicesUpdaterDecode:

     def call_begin_forward(
         self,
-        wrapper: Union[
-            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper: BatchDecodeWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -775,25 +640,6 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1

-        if global_config.enable_flashinfer_mla:
-            sm_scale = 1.0 / math.sqrt(192)
-            q_indptr = torch.arange(0, bs + 1).to(0).int()
-            kv_lens = paged_kernel_lens.to(torch.int32)
-            wrapper.plan(
-                q_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_lens,
-                self.num_qo_heads,
-                512,
-                64,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
-            )
-        else:
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -845,9 +691,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[
-            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
@@ -861,9 +705,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[
-            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
@@ -972,9 +814,7 @@ class FlashInferIndicesUpdaterPrefill:
     def call_begin_forward(
         self,
         wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
-        wrapper_paged: Union[
-            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -1020,17 +860,6 @@ class FlashInferIndicesUpdaterPrefill:

         # extend part
         if use_ragged:
-            if global_config.enable_flashinfer_mla:
-                wrapper_ragged.begin_forward(
-                    qo_indptr=qo_indptr,
-                    kv_indptr=qo_indptr,
-                    num_qo_heads=self.num_qo_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_dim_qk=192,
-                    head_dim_vo=128,
-                    q_data_type=self.q_data_type,
-                )
-            else:
             wrapper_ragged.begin_forward(
                 qo_indptr,
                 qo_indptr,
@@ -1040,7 +869,6 @@ class FlashInferIndicesUpdaterPrefill:
                 q_data_type=self.q_data_type,
             )

-        if not global_config.enable_flashinfer_mla:
         # cached part
         wrapper_paged.begin_forward(
             qo_indptr,
@@ -1055,26 +883,6 @@ class FlashInferIndicesUpdaterPrefill:
             custom_mask=custom_mask,
             non_blocking=True,
         )
-        elif (
-            global_config.enable_flashinfer_mla
-            and not global_server_args_dict["disable_radix_cache"]
-        ):
-            # mla paged prefill
-            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
-            wrapper_paged.plan(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_len_arr,
-                self.num_qo_heads,
-                512,
-                64,
-                1,
-                True,
-                1 / math.sqrt(192),
-                self.data_type,
-                self.data_type,
-            )


 class FlashInferMultiStepDraftBackend:
@@ -2,13 +2,13 @@ from __future__ import annotations

 """
 Support attention backend for flashinfer MLA.
-When radix cache is enabled, the backend only uses BatchMLAPaged wrapper when forwarding.
-When radix cache is disabled, the backend uses BatchPrefill wrappers for prefilling (with or without prefix cache),
+The flashinfer_mla_disable_ragged flag controls whether the ragged prefill wrapper is used; it defaults to false.
+When it's set to true, all wrappers are BatchMLAPaged wrappers.
+When it's set to false, the backend uses BatchRagged and BatchMLAPaged wrappers for prefilling,
 and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """

-import math
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

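Put differently, the backend picks the prefill wrapper per batch from this flag plus whether any request in the batch carries a cached prefix. A small sketch of that rule, using the same expression the backend computes in `init_forward_metadata` later in this diff (the helper function name here is illustrative):

```python
def pick_prefill_wrapper(disable_ragged: bool, extend_prefix_lens_cpu) -> str:
    # Mirrors: use_ragged = not flashinfer_mla_disable_ragged and extend_no_prefix
    extend_no_prefix = not any(extend_prefix_lens_cpu)
    use_ragged = (not disable_ragged) and extend_no_prefix
    return (
        "BatchPrefillWithRaggedKVCacheWrapper"
        if use_ragged
        else "BatchMLAPagedAttentionWrapper"
    )

assert pick_prefill_wrapper(False, [0, 0]) == "BatchPrefillWithRaggedKVCacheWrapper"
assert pick_prefill_wrapper(False, [0, 7]) == "BatchMLAPagedAttentionWrapper"  # cached prefix present
assert pick_prefill_wrapper(True, [0, 0]) == "BatchMLAPagedAttentionWrapper"   # flag set
```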
@@ -18,7 +18,6 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
-    should_use_tensor_core,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -32,11 +31,10 @@ if TYPE_CHECKING:

 if is_flashinfer_available():
     from flashinfer import (
-        BatchPrefillWithPagedKVCacheWrapper,
+        BatchMLAPagedAttentionWrapper,
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.mla import BatchMLAPagedAttentionWrapper


 @dataclass
@@ -46,9 +44,7 @@ class DecodeMetadata:

 @dataclass
 class PrefillMetadata:
-    prefill_wrapper: Union[
-        BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-    ]
+    prefill_wrapper: BatchMLAPagedAttentionWrapper
     use_ragged: bool

@@ -62,7 +58,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def __init__(
         self,
         model_runner: ModelRunner,
-        kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()

@@ -82,12 +77,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             self.workspace_buffer = global_workspace_buffer

         max_bs = model_runner.req_to_token_pool.size
-        if kv_indptr_buf is None:
         self.kv_indptr = torch.zeros(
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
-        else:
-            self.kv_indptr = kv_indptr_buf

         self.qo_indptr = torch.zeros(
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
@@ -97,22 +89,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )

+        self.q_indptr_decode = torch.arange(
+            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
+        )
+
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD"
         )

-        if not global_server_args_dict["disable_radix_cache"]:
-            # use mla paged prefill
         self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
             self.workspace_buffer,
             backend="auto",
         )
-        else:
-            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.workspace_buffer,
-                "NHD",
-                backend="auto",
-            )
         self.decode_wrapper = BatchMLAPagedAttentionWrapper(
             self.workspace_buffer, backend="auto"
         )
@@ -141,7 +130,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             self.forward_metadata = DecodeMetadata(self.decode_wrapper)
         else:
             prefix_lens = forward_batch.extend_prefix_lens
-            use_ragged = global_server_args_dict["disable_radix_cache"]
+            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
+            use_ragged = (
+                not global_server_args_dict["flashinfer_mla_disable_ragged"]
+                and extend_no_prefix
+            )

             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
@@ -241,45 +234,37 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):

         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap

-        if not global_server_args_dict["disable_radix_cache"]:
-            # use mla paged prefill
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
         qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

+        # Save kv cache
+        if save_kv_cache and k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
+        if self.forward_metadata.use_ragged:
+            # ragged prefill
+            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+                qall,
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+        else:
+            # mla paged prefill
             o = prefill_wrapper_paged.run(
                 qall[:, :, : layer.v_head_dim],
                 qall[:, :, layer.v_head_dim :],
                 k_buf[:, :, : layer.v_head_dim],
                 k_buf[:, :, layer.v_head_dim :],
             )
-        else:
-            # use mla ragged prefill
-            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
-            # FIXME: Here should be another prefill_paged to call
-
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

@@ -334,6 +319,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.q_indptr = attn_backend.q_indptr_decode

     def update(
         self,
@@ -342,12 +328,13 @@ class FlashInferMLAIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrapper: BatchMLAPagedAttentionWrapper,
     ):
-        decode_wrappers = decode_wrapper or self.decode_wrapper
+        decode_wrapper = decode_wrapper or self.decode_wrapper
         self.call_begin_forward(
             decode_wrapper,
             req_pool_indices,
             seq_lens,
             seq_lens_sum,
+            self.q_indptr,
             self.kv_indptr,
         )

@@ -357,14 +344,19 @@ class FlashInferMLAIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
+        q_indptr: torch.Tensor,
         kv_indptr: torch.Tensor,
     ):
         bs = len(req_pool_indices)
+        q_indptr = q_indptr[: bs + 1]
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
         kv_indices = torch.empty(
             paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
         )
+        kv_lens = paged_kernel_lens.to(torch.int32)
+        sm_scale = self.scaling

         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
             req_pool_indices,
@@ -375,9 +367,6 @@ class FlashInferMLAIndicesUpdaterDecode:
             self.req_to_token.shape[1],
         )

-        sm_scale = self.scaling
-        q_indptr = torch.arange(0, bs + 1).to(0).int()
-        kv_lens = paged_kernel_lens.to(torch.int32)
         wrapper.plan(
             q_indptr,
             kv_indptr,
@@ -397,12 +386,9 @@ class FlashInferMLAIndicesUpdaterDecode:
 class FlashInferMLAIndicesUpdaterPrefill:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Parse Constants
-        self.num_qo_heads = (
+        self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
         self.kv_lora_rank = model_runner.model_config.kv_lora_rank
         self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
@@ -425,9 +411,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrapper_paged: Union[
-            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
         use_ragged: bool,
     ):
         if use_ragged:
@@ -453,9 +437,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
     def call_begin_forward(
         self,
         wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
-        wrapper_paged: Union[
-            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper_paged: BatchMLAPagedAttentionWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -466,7 +448,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
         use_ragged: bool,
     ):
         bs = len(req_pool_indices)
-        # Normal extend
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
         kv_indices = torch.empty(
@@ -488,19 +469,18 @@ class FlashInferMLAIndicesUpdaterPrefill:
         qo_indptr = qo_indptr[: bs + 1]
         sm_scale = self.scaling

-        # extend part
         if use_ragged:
+            # ragged prefill
             wrapper_ragged.begin_forward(
                 qo_indptr=qo_indptr,
                 kv_indptr=qo_indptr,
-                num_qo_heads=self.num_qo_heads,
-                num_kv_heads=self.num_kv_heads,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                 head_dim_vo=self.v_head_dim,
                 q_data_type=self.q_data_type,
             )
-        if not global_server_args_dict["disable_radix_cache"]:
+        else:
             # mla paged prefill
             kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
             wrapper_paged.plan(
@@ -508,7 +488,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 kv_indptr,
                 kv_indices,
                 kv_len_arr,
-                self.num_qo_heads,
+                self.num_local_heads,
                 self.kv_lora_rank,
                 self.qk_rope_head_dim,
                 1,
@@ -517,5 +497,3 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 self.q_data_type,
                 self.data_type,
             )
-
-            # FIXME: Here should be some logic for prefill paged when not using radix cache?
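A note on the slicing in the `run()` calls above: the MLA KV pool stores the compressed latent and the rope part side by side in one buffer, and the paged wrapper takes them as separate tensors. A shape-only sketch, assuming the DeepSeek-V3 dimensions that appear in this diff (`kv_lora_rank=512`, `qk_rope_head_dim=64`); the tensor names are illustrative:

```python
import torch

num_tokens, num_heads = 8, 16
kv_lora_rank, qk_rope_head_dim = 512, 64  # here layer.v_head_dim and head_dim - v_head_dim

# qall / k_buf as the backend views them: last dim = kv_lora_rank + qk_rope_head_dim
qall = torch.randn(num_tokens, num_heads, kv_lora_rank + qk_rope_head_dim)
k_buf = torch.randn(num_tokens, 1, kv_lora_rank + qk_rope_head_dim)

q_nope, q_rope = qall[:, :, :kv_lora_rank], qall[:, :, kv_lora_rank:]  # (8, 16, 512) / (8, 16, 64)
ckv, k_rope = k_buf[:, :, :kv_lora_rank], k_buf[:, :, kv_lora_rank:]   # (8, 1, 512) / (8, 1, 64)

# These four tensors are what the paged MLA wrapper's run(...) receives.
print(q_nope.shape, q_rope.shape, ckv.shape, k_rope.shape)
```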
@@ -67,6 +67,7 @@ global_server_args_dict = {
     "device": ServerArgs.device,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
 }

 logger = logging.getLogger(__name__)
@@ -182,6 +182,7 @@ class ModelRunner:
                 "device": server_args.device,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                 "disable_radix_cache": server_args.disable_radix_cache,
+                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
             }
         )

@@ -520,10 +520,11 @@ class DeepseekV2AttentionMLA(nn.Module):

         def no_absorb() -> bool:
             if global_server_args_dict["enable_flashinfer_mla"]:
-                # Flashinfer MLA: Only do not use absorb when prefilling/extending without radix cache
+                # Flashinfer MLA: Do not absorb when enabling ragged prefill
                 return (
-                    global_server_args_dict["disable_radix_cache"]
+                    not global_server_args_dict["flashinfer_mla_disable_ragged"]
                     and forward_batch.forward_mode.is_extend()
+                    and forward_batch.extend_prefix_lens.sum() == 0
                 )
             else:
                 # Triton: Use normal computation for prefill and use weight absorption for extend/decode
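The "absorb" in the comment refers to MLA weight absorption: folding the KV up-projections into the query and output projections so attention runs directly in the compressed `kv_lora_rank` space, while the ragged prefill path attends over the full per-head dimensions. A toy sketch of the two dimension regimes, using the DeepSeek-V3 sizes that appear elsewhere in this diff (128/64/512/128); the function name is illustrative:

```python
def mla_attention_dims(use_ragged_prefill: bool) -> dict:
    qk_nope_head_dim, qk_rope_head_dim = 128, 64
    kv_lora_rank, v_head_dim = 512, 128
    if use_ragged_prefill:
        # no_absorb() is True: ordinary multi-head attention on the full heads
        return {"head_dim_qk": qk_nope_head_dim + qk_rope_head_dim, "head_dim_vo": v_head_dim}
    # absorbed path: paged MLA prefill/decode in the compressed latent space
    return {"head_dim_ckv": kv_lora_rank, "head_dim_kpe": qk_rope_head_dim}

print(mla_attention_dims(True))   # {'head_dim_qk': 192, 'head_dim_vo': 128}
print(mla_attention_dims(False))  # {'head_dim_ckv': 512, 'head_dim_kpe': 64}
```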
@@ -167,6 +167,7 @@ class ServerArgs:
     tool_call_parser: str = None
     enable_hierarchical_cache: bool = False
     enable_flashinfer_mla: bool = False
+    flashinfer_mla_disable_ragged: bool = False

     def __post_init__(self):
         # Set missing default values
@@ -713,6 +714,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer MLA optimization",
         )
+        parser.add_argument(
+            "--flashinfer-mla-disable-ragged",
+            action="store_true",
+            help="Not using ragged prefill wrapper when running flashinfer mla",
+        )

         # Speculative decoding
         parser.add_argument(
@@ -23,6 +23,7 @@ suites = {
         "test_gguf.py",
         "test_input_embeddings.py",
         "test_mla.py",
+        "test_mla_flashinfer.py",
         "test_mla_fp8.py",
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
test/srt/test_mla_flashinfer.py (new file, 104 lines)
@@ -0,0 +1,104 @@
+import unittest
+from types import SimpleNamespace
+
+import torch
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestFlashinferMLA(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "sgl-project/sglang-ci-dsv3-test"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        other_args = ["--trust-remote-code"]
+        if torch.cuda.is_available() and torch.version.cuda:
+            other_args.extend(
+                [
+                    "--enable-torch-compile",
+                    "--cuda-graph-max-bs",
+                    "2",
+                    "--enable-flashinfer-mla",
+                ]
+            )
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+
+class TestFlashinferMLANoRagged(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "sgl-project/sglang-ci-dsv3-test"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        other_args = ["--trust-remote-code"]
+        if torch.cuda.is_available() and torch.version.cuda:
+            other_args.extend(
+                [
+                    "--enable-torch-compile",
+                    "--disable-cuda-graph",
+                    "--cuda-graph-max-bs",
+                    "2",
+                    "--enable-flashinfer-mla",
+                    "--flashinfer-mla-disable-ragged",
+                ]
+            )
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+
+if __name__ == "__main__":
+    unittest.main()