[Revision] Replace enable_flashinfer_mla argument with attention_backend (#5052)
This commit is contained in:
@@ -71,8 +71,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
self.device = model_runner.device
|
||||
self.skip_prefill = skip_prefill
|
||||
|
||||
global_config.enable_flashinfer_mla = True
|
||||
|
||||
# Allocate buffers
|
||||
global global_workspace_buffer
|
||||
if global_workspace_buffer is None:
|
||||
|
||||
@@ -76,7 +76,6 @@ global_server_args_dict = {
|
||||
"device": ServerArgs.device,
|
||||
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
|
||||
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
||||
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
|
||||
"enable_flashmla": ServerArgs.enable_flashmla,
|
||||
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
||||
@@ -1437,7 +1436,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
# Create seq_lens_cpu when needed
|
||||
if (
|
||||
global_server_args_dict["enable_flashinfer_mla"]
|
||||
(
|
||||
global_server_args_dict["use_mla_backend"]
|
||||
and global_server_args_dict["attention_backend"] == "flashinfer"
|
||||
)
|
||||
or global_server_args_dict["enable_flashmla"]
|
||||
or global_server_args_dict["attention_backend"] == "fa3"
|
||||
):
|
||||
|
||||
@@ -75,6 +75,7 @@ from sglang.srt.utils import (
|
||||
get_available_gpu_memory,
|
||||
init_custom_process_group,
|
||||
is_cuda,
|
||||
is_flashinfer_available,
|
||||
is_hip,
|
||||
monkey_patch_p2p_access_check,
|
||||
monkey_patch_vllm_gguf_config,
|
||||
@@ -123,6 +124,10 @@ class ModelRunner:
|
||||
self.page_size = server_args.page_size
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.use_mla_backend = (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and not server_args.disable_mla
|
||||
)
|
||||
|
||||
# Model-specific adjustment
|
||||
self.model_specific_adjustment()
|
||||
@@ -151,7 +156,6 @@ class ModelRunner:
|
||||
"device": server_args.device,
|
||||
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
||||
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
||||
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
|
||||
"enable_flashmla": server_args.enable_flashmla,
|
||||
"disable_radix_cache": server_args.disable_radix_cache,
|
||||
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
||||
@@ -159,6 +163,7 @@ class ModelRunner:
|
||||
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
|
||||
"n_share_experts_fusion": server_args.n_share_experts_fusion,
|
||||
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
|
||||
"use_mla_backend": self.use_mla_backend,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -219,27 +224,38 @@ class ModelRunner:
|
||||
def model_specific_adjustment(self):
|
||||
server_args = self.server_args
|
||||
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and not server_args.disable_mla
|
||||
):
|
||||
if server_args.enable_flashinfer_mla:
|
||||
# TODO: remove this branch after enable_flashinfer_mla is deprecated
|
||||
logger.info("MLA optimization is turned on. Use flashinfer backend.")
|
||||
server_args.attention_backend = "flashinfer"
|
||||
elif server_args.enable_flashmla:
|
||||
# TODO: remove this branch after enable_flashmla is deprecated
|
||||
logger.info("MLA optimization is turned on. Use flashmla decode.")
|
||||
server_args.attention_backend = "flashmla"
|
||||
elif server_args.attention_backend is None:
|
||||
# By default, use flashinfer for non-mla attention and triton for mla attention
|
||||
if not self.use_mla_backend:
|
||||
server_args.attention_backend = (
|
||||
"flashinfer" if is_flashinfer_available() else "triton"
|
||||
)
|
||||
else:
|
||||
server_args.attention_backend = "triton"
|
||||
logger.info(
|
||||
f"Attention backend not set. Use {server_args.attention_backend} backend by default."
|
||||
)
|
||||
elif self.use_mla_backend:
|
||||
# TODO: add MLA optimization on CPU
|
||||
if server_args.device != "cpu":
|
||||
if server_args.enable_flashinfer_mla:
|
||||
if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
|
||||
logger.info(
|
||||
"MLA optimization is turned on. Use flashinfer mla backend."
|
||||
)
|
||||
server_args.attention_backend = "flashinfer_mla"
|
||||
elif server_args.enable_flashmla:
|
||||
logger.info("MLA optimization is turned on. Use flashmla decode.")
|
||||
server_args.attention_backend = "flashmla"
|
||||
elif server_args.attention_backend == "fa3":
|
||||
logger.info(
|
||||
f"MLA optimization is turned on. Use flash attention 3 backend."
|
||||
f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
|
||||
)
|
||||
else:
|
||||
logger.info("MLA optimization is turned on. Use triton backend.")
|
||||
server_args.attention_backend = "triton"
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for MLA: {server_args.attention_backend}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"MLA optimization not supported on CPU.")
|
||||
|
||||
if server_args.enable_double_sparsity:
|
||||
logger.info(
|
||||
@@ -637,10 +653,7 @@ class ModelRunner:
|
||||
available_gpu_memory = get_available_gpu_memory(
|
||||
self.device, self.gpu_id, distributed=self.tp_size > 1
|
||||
)
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and not self.server_args.disable_mla
|
||||
):
|
||||
if self.use_mla_backend:
|
||||
cell_size = (
|
||||
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
||||
* self.model_config.num_hidden_layers
|
||||
@@ -751,10 +764,7 @@ class ModelRunner:
|
||||
# Draft worker shares req_to_token_pool with the target worker.
|
||||
assert self.is_draft_worker
|
||||
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and not self.server_args.disable_mla
|
||||
):
|
||||
if self.use_mla_backend:
|
||||
self.token_to_kv_pool = MLATokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
page_size=self.page_size,
|
||||
@@ -825,14 +835,21 @@ class ModelRunner:
|
||||
def init_attention_backend(self):
|
||||
"""Init attention kernel backend."""
|
||||
if self.server_args.attention_backend == "flashinfer":
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
FlashInferAttnBackend,
|
||||
)
|
||||
if not self.use_mla_backend:
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
FlashInferAttnBackend,
|
||||
)
|
||||
|
||||
# Init streams
|
||||
if self.server_args.speculative_algorithm == "EAGLE":
|
||||
self.plan_stream_for_flashinfer = torch.cuda.Stream()
|
||||
self.attn_backend = FlashInferAttnBackend(self)
|
||||
# Init streams
|
||||
if self.server_args.speculative_algorithm == "EAGLE":
|
||||
self.plan_stream_for_flashinfer = torch.cuda.Stream()
|
||||
self.attn_backend = FlashInferAttnBackend(self)
|
||||
else:
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||
FlashInferMLAAttnBackend,
|
||||
)
|
||||
|
||||
self.attn_backend = FlashInferMLAAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "triton":
|
||||
assert self.sliding_window_size is None, (
|
||||
"Window attention is not supported in the triton attention backend. "
|
||||
@@ -858,12 +875,6 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
self.attn_backend = TorchNativeAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "flashinfer_mla":
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||
FlashInferMLAAttnBackend,
|
||||
)
|
||||
|
||||
self.attn_backend = FlashInferMLAAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "flashmla":
|
||||
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
|
||||
|
||||
|
||||
@@ -686,7 +686,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.w_vc = None
|
||||
self.w_scale = None
|
||||
|
||||
self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
|
||||
self.flashinfer_mla_disable_ragged = global_server_args_dict[
|
||||
"flashinfer_mla_disable_ragged"
|
||||
]
|
||||
@@ -694,7 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
|
||||
|
||||
def no_absorb(self, forward_batch: ForwardBatch) -> bool:
|
||||
if self.enable_flashinfer_mla:
|
||||
if self.attention_backend == "flashinfer":
|
||||
# Flashinfer MLA: Do not absorb when enabling ragged prefill
|
||||
return (
|
||||
not self.flashinfer_mla_disable_ragged
|
||||
|
||||
@@ -179,7 +179,7 @@ class ServerArgs:
|
||||
tool_call_parser: Optional[str] = None
|
||||
enable_hierarchical_cache: bool = False
|
||||
hicache_ratio: float = 2.0
|
||||
enable_flashinfer_mla: bool = False
|
||||
enable_flashinfer_mla: bool = False # TODO: remove this argument
|
||||
enable_flashmla: bool = False
|
||||
flashinfer_mla_disable_ragged: bool = False
|
||||
warmups: Optional[str] = None
|
||||
@@ -267,15 +267,11 @@ class ServerArgs:
|
||||
else:
|
||||
self.cuda_graph_max_bs = 160
|
||||
|
||||
# Choose kernel backends
|
||||
# Set kernel backends for hpu device
|
||||
if self.device == "hpu":
|
||||
self.attention_backend = "torch_native"
|
||||
self.sampling_backend = "pytorch"
|
||||
|
||||
if self.attention_backend is None:
|
||||
self.attention_backend = (
|
||||
"flashinfer" if is_flashinfer_available() else "triton"
|
||||
)
|
||||
if self.sampling_backend is None:
|
||||
self.sampling_backend = (
|
||||
"flashinfer" if is_flashinfer_available() else "pytorch"
|
||||
@@ -842,7 +838,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer-mla",
|
||||
action="store_true",
|
||||
help="Enable FlashInfer MLA optimization",
|
||||
help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-flashmla",
|
||||
|
||||
@@ -11,7 +11,11 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
|
||||
from sglang.srt.layers.dp_attention import disable_dp_size
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, get_last_loc
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
ScheduleBatch,
|
||||
get_last_loc,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
@@ -146,15 +150,26 @@ class EAGLEWorker(TpModelWorker):
|
||||
def init_attention_backend(self):
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
if self.server_args.attention_backend == "flashinfer":
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
FlashInferMultiStepDraftBackend,
|
||||
)
|
||||
if not global_server_args_dict["use_mla_backend"]:
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
FlashInferMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
else:
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||
FlashInferMLAMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = True
|
||||
@@ -171,19 +186,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.draft_extend_attn_backend = None
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = False
|
||||
elif self.server_args.attention_backend == "flashinfer_mla":
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||
FlashInferMLAMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = True
|
||||
elif self.server_args.attention_backend == "fa3":
|
||||
from sglang.srt.layers.attention.flashattention_backend import (
|
||||
FlashAttentionMultiStepBackend,
|
||||
|
||||
Reference in New Issue
Block a user