enable prefix cache with dp (#10459)
This commit is contained in:
@@ -564,18 +564,6 @@ class ModelRunner:
|
|||||||
if not self.use_mla_backend:
|
if not self.use_mla_backend:
|
||||||
server_args.disable_chunked_prefix_cache = True
|
server_args.disable_chunked_prefix_cache = True
|
||||||
|
|
||||||
# TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
|
|
||||||
# For more details, see: https://github.com/sgl-project/sglang/issues/8616
|
|
||||||
elif (
|
|
||||||
self.dp_size > 1
|
|
||||||
and is_sm100_supported()
|
|
||||||
and server_args.attention_backend != "triton"
|
|
||||||
and server_args.attention_backend == "trtllm_mla"
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
"Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
|
|
||||||
)
|
|
||||||
server_args.disable_chunked_prefix_cache = True
|
|
||||||
if not server_args.disable_chunked_prefix_cache:
|
if not server_args.disable_chunked_prefix_cache:
|
||||||
logger.info("Chunked prefix cache is turned on.")
|
logger.info("Chunked prefix cache is turned on.")
|
||||||
|
|
||||||
|
|||||||
@@ -1098,7 +1098,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
attention_backend == "flashinfer" or attention_backend == "flashmla"
|
attention_backend == "flashinfer" or attention_backend == "flashmla"
|
||||||
) and self.flashinfer_mla_disable_ragged
|
) and self.flashinfer_mla_disable_ragged
|
||||||
|
|
||||||
original_mode = getattr(forward_batch, "_original_forward_mode", None)
|
|
||||||
if (
|
if (
|
||||||
not disable_ragged
|
not disable_ragged
|
||||||
and forward_batch.forward_mode.is_extend()
|
and forward_batch.forward_mode.is_extend()
|
||||||
@@ -1111,15 +1110,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
)
|
)
|
||||||
or sum_extend_prefix_lens == 0
|
or sum_extend_prefix_lens == 0
|
||||||
)
|
)
|
||||||
# TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
|
|
||||||
# dp case. Redirect to mla kernel as a workaround.
|
|
||||||
# Tracked by https://github.com/sgl-project/sglang/issues/9806.
|
|
||||||
and not (
|
|
||||||
original_mode is not None
|
|
||||||
and original_mode.is_decode()
|
|
||||||
and is_sm100_supported()
|
|
||||||
and self.current_attention_backend in ("cutlass_mla", "flashinfer")
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
else:
|
else:
|
||||||
@@ -1128,14 +1118,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
# TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
|
# TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
|
||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
elif attention_backend == "trtllm_mla":
|
elif attention_backend == "trtllm_mla":
|
||||||
original_mode = getattr(forward_batch, "_original_forward_mode", None)
|
|
||||||
if (
|
|
||||||
original_mode is not None
|
|
||||||
and original_mode.is_decode()
|
|
||||||
and is_sm100_supported()
|
|
||||||
):
|
|
||||||
return _dispatch_mla_subtype()
|
|
||||||
|
|
||||||
sum_extend_prefix_lens = (
|
sum_extend_prefix_lens = (
|
||||||
sum(forward_batch.extend_prefix_lens_cpu)
|
sum(forward_batch.extend_prefix_lens_cpu)
|
||||||
if forward_batch.extend_prefix_lens_cpu is not None
|
if forward_batch.extend_prefix_lens_cpu is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user