Fix chunked prefix cache for nvfp4 (#10180)
Co-authored-by: Elfie Guo <elfieg@nvidia.com>
This commit is contained in:
@@ -20,6 +20,7 @@ from sglang.srt.layers.attention.utils import (
|
|||||||
create_flashmla_kv_indices_triton,
|
create_flashmla_kv_indices_triton,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.utils import is_flashinfer_available
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
@@ -72,7 +73,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
kv_indptr_buf: Optional[torch.Tensor] = None,
|
kv_indptr_buf: Optional[torch.Tensor] = None,
|
||||||
q_indptr_decode_buf: Optional[torch.Tensor] = None,
|
q_indptr_decode_buf: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
|
super().__init__(
|
||||||
|
model_runner,
|
||||||
|
skip_prefill,
|
||||||
|
kv_indptr_buf,
|
||||||
|
q_indptr_decode_buf,
|
||||||
|
)
|
||||||
|
|
||||||
config = model_runner.model_config
|
config = model_runner.model_config
|
||||||
|
|
||||||
@@ -112,6 +118,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
|
self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
|
||||||
self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
|
self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
|
||||||
|
|
||||||
|
self.disable_chunked_prefix_cache = global_server_args_dict[
|
||||||
|
"disable_chunked_prefix_cache"
|
||||||
|
]
|
||||||
|
|
||||||
def _calc_padded_blocks(self, max_seq_len: int) -> int:
|
def _calc_padded_blocks(self, max_seq_len: int) -> int:
|
||||||
"""
|
"""
|
||||||
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
|
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
|
||||||
@@ -301,6 +311,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
and not forward_batch.forward_mode.is_target_verify()
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
):
|
):
|
||||||
|
if self.disable_chunked_prefix_cache:
|
||||||
|
super().init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
|
seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
|
||||||
cum_seq_lens_q = torch.cat(
|
cum_seq_lens_q = torch.cat(
|
||||||
(
|
(
|
||||||
@@ -540,6 +553,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
return super().forward_extend(
|
return super().forward_extend(
|
||||||
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
||||||
)
|
)
|
||||||
|
# chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
|
||||||
|
if forward_batch.attn_attend_prefix_cache is None:
|
||||||
|
return super().forward_extend(
|
||||||
|
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
||||||
|
)
|
||||||
|
|
||||||
if not forward_batch.attn_attend_prefix_cache:
|
if not forward_batch.attn_attend_prefix_cache:
|
||||||
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
|||||||
@@ -560,18 +560,19 @@ class ModelRunner:
|
|||||||
|
|
||||||
if not self.use_mla_backend:
|
if not self.use_mla_backend:
|
||||||
server_args.disable_chunked_prefix_cache = True
|
server_args.disable_chunked_prefix_cache = True
|
||||||
|
|
||||||
# TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
|
# TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
|
||||||
# For more details, see: https://github.com/sgl-project/sglang/issues/8616
|
# For more details, see: https://github.com/sgl-project/sglang/issues/8616
|
||||||
elif (
|
elif (
|
||||||
self.dp_size > 1
|
self.dp_size > 1
|
||||||
and is_sm100_supported()
|
and is_sm100_supported()
|
||||||
and server_args.attention_backend != "triton"
|
and server_args.attention_backend != "triton"
|
||||||
|
and server_args.attention_backend == "trtllm_mla"
|
||||||
):
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
|
"Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
|
||||||
)
|
)
|
||||||
server_args.disable_chunked_prefix_cache = True
|
server_args.disable_chunked_prefix_cache = True
|
||||||
|
|
||||||
if not server_args.disable_chunked_prefix_cache:
|
if not server_args.disable_chunked_prefix_cache:
|
||||||
logger.info("Chunked prefix cache is turned on.")
|
logger.info("Chunked prefix cache is turned on.")
|
||||||
|
|
||||||
|
|||||||
@@ -1087,6 +1087,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
disable_ragged = (
|
disable_ragged = (
|
||||||
attention_backend == "flashinfer" or attention_backend == "flashmla"
|
attention_backend == "flashinfer" or attention_backend == "flashmla"
|
||||||
) and self.flashinfer_mla_disable_ragged
|
) and self.flashinfer_mla_disable_ragged
|
||||||
|
|
||||||
|
original_mode = getattr(forward_batch, "_original_forward_mode", None)
|
||||||
if (
|
if (
|
||||||
not disable_ragged
|
not disable_ragged
|
||||||
and forward_batch.forward_mode.is_extend()
|
and forward_batch.forward_mode.is_extend()
|
||||||
@@ -1099,15 +1101,40 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
)
|
)
|
||||||
or sum_extend_prefix_lens == 0
|
or sum_extend_prefix_lens == 0
|
||||||
)
|
)
|
||||||
|
# TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
|
||||||
|
# dp case. Redirect to mla kernel as a workaround.
|
||||||
|
# Tracked by https://github.com/sgl-project/sglang/issues/9806.
|
||||||
|
and not (
|
||||||
|
original_mode is not None
|
||||||
|
and original_mode.is_decode()
|
||||||
|
and is_sm100_supported()
|
||||||
|
and self.current_attention_backend in ("cutlass_mla", "flashinfer")
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
else:
|
else:
|
||||||
return _dispatch_mla_subtype()
|
return _dispatch_mla_subtype()
|
||||||
elif attention_backend == "trtllm_mla":
|
elif attention_backend == "trtllm_mla":
|
||||||
|
original_mode = getattr(forward_batch, "_original_forward_mode", None)
|
||||||
|
if (
|
||||||
|
original_mode is not None
|
||||||
|
and original_mode.is_decode()
|
||||||
|
and is_sm100_supported()
|
||||||
|
):
|
||||||
|
return _dispatch_mla_subtype()
|
||||||
|
|
||||||
|
sum_extend_prefix_lens = (
|
||||||
|
sum(forward_batch.extend_prefix_lens_cpu)
|
||||||
|
if forward_batch.extend_prefix_lens_cpu is not None
|
||||||
|
else 0
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
forward_batch.forward_mode.is_extend()
|
forward_batch.forward_mode.is_extend()
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
|
and (
|
||||||
|
not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user