[feat] Support different attention backends for prefill and decode (#6338)
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
python/sglang/srt/layers/attention/hybrid_attn_backend.py (new file, 100 lines)
@@ -0,0 +1,100 @@
from typing import TYPE_CHECKING, Optional, Union

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


class HybridAttnBackend(AttentionBackend):
    """Support different backends for prefill and decode."""

    def __init__(
        self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
    ):
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode():
            self.decode_backend.init_forward_metadata(forward_batch)
        else:
            self.prefill_backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        self.decode_backend.init_forward_metadata_capture_cuda_graph(
            bs,
            num_tokens,
            req_pool_indices,
            seq_lens,
            encoder_lens,
            forward_mode,
            spec_info,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        self.decode_backend.init_forward_metadata_replay_cuda_graph(
            bs,
            req_pool_indices,
            seq_lens,
            seq_lens_sum,
            encoder_lens,
            forward_mode,
            spec_info,
            seq_lens_cpu,
        )

    def get_cuda_graph_seq_len_fill_value(self):
        return self.decode_backend.get_cuda_graph_seq_len_fill_value()

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        return self.decode_backend.forward_decode(
            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
        )

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        return self.prefill_backend.forward_extend(
            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
        )
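For orientation, here is a minimal sketch of the dispatch behavior. The SimpleNamespace stubs and fake batches are illustrative stand-ins (real backends would be e.g. FlashInferAttnBackend), assuming HybridAttnBackend can be instantiated standalone:

from types import SimpleNamespace

# Stand-ins for real prefill/decode backends; only the method the dispatcher
# calls here is stubbed out.
prefill = SimpleNamespace(init_forward_metadata=lambda fb: print("prefill plan"))
decode = SimpleNamespace(init_forward_metadata=lambda fb: print("decode plan"))
hybrid = HybridAttnBackend(prefill_backend=prefill, decode_backend=decode)

# Fake batches exposing only forward_mode.is_decode(), the attribute that
# init_forward_metadata inspects.
decode_batch = SimpleNamespace(forward_mode=SimpleNamespace(is_decode=lambda: True))
extend_batch = SimpleNamespace(forward_mode=SimpleNamespace(is_decode=lambda: False))

hybrid.init_forward_metadata(decode_batch)  # -> "decode plan"
hybrid.init_forward_metadata(extend_batch)  # -> "prefill plan"

Note that all CUDA-graph hooks delegate to the decode backend, presumably because CUDA graphs are only captured for decode batches.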
python/sglang/srt/managers/schedule_batch.py
@@ -1690,16 +1690,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_prefix_lens = self.prefix_lens
         extend_logprob_start_lens = self.extend_logprob_start_lens
 
+        if self.forward_mode.is_decode_or_idle():
+            attention_backend_str = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
         # Create seq_lens_cpu when needed
         if (
-            global_server_args_dict["attention_backend"] == "fa3"
+            attention_backend_str == "fa3"
             or (
                 global_server_args_dict["use_mla_backend"]
-                and global_server_args_dict["attention_backend"] == "flashinfer"
+                and attention_backend_str == "flashinfer"
             )
-            or global_server_args_dict["attention_backend"] == "flashmla"
-            or global_server_args_dict["attention_backend"] == "cutlass_mla"
-            or global_server_args_dict["attention_backend"] == "ascend"
+            or attention_backend_str == "flashmla"
+            or attention_backend_str == "cutlass_mla"
+            or attention_backend_str == "ascend"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
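The gating condition above reduces to a small predicate over the per-mode backend string; a hedged restatement (the helper name is illustrative, not in the diff):

def needs_seq_lens_cpu(
    backend: str, use_mla_backend: bool, enable_two_batch_overlap: bool
) -> bool:
    # Backends that need sequence lengths on the host, mirroring the condition above.
    return (
        backend == "fa3"
        or (use_mla_backend and backend == "flashinfer")
        or backend in ("flashmla", "cutlass_mla", "ascend")
        or enable_two_batch_overlap
    )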
python/sglang/srt/model_executor/model_runner.py
@@ -1308,9 +1308,58 @@ class ModelRunner:
         else:
             self.attn_backend = self._get_attention_backend()
 
-    # TODO unify with 6338
     def _get_attention_backend(self):
-        if self.server_args.attention_backend == "flashinfer":
+        """Init attention kernel backend."""
+        self.decode_attention_backend_str = (
+            self.server_args.decode_attention_backend
+            if self.server_args.decode_attention_backend
+            else self.server_args.attention_backend
+        )
+        self.prefill_attention_backend_str = (
+            self.server_args.prefill_attention_backend
+            if self.server_args.prefill_attention_backend
+            else self.server_args.attention_backend
+        )
+        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
+            assert (
+                self.server_args.speculative_algorithm is None
+            ), "Currently HybridAttentionBackend does not support speculative decoding."
+            from sglang.srt.layers.attention.hybrid_attn_backend import (
+                HybridAttnBackend,
+            )
+
+            attn_backend = HybridAttnBackend(
+                decode_backend=self._get_attention_backend_from_str(
+                    self.decode_attention_backend_str
+                ),
+                prefill_backend=self._get_attention_backend_from_str(
+                    self.prefill_attention_backend_str
+                ),
+            )
+            logger.info(
+                f"Using hybrid attention backend for decode and prefill: "
+                f"decode_backend={self.decode_attention_backend_str}, "
+                f"prefill_backend={self.prefill_attention_backend_str}."
+            )
+            logger.warning(
+                "The attention backend specified by --attention-backend or the default backend may be overridden. "
+                "The hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problems."
+            )
+        else:
+            attn_backend = self._get_attention_backend_from_str(
+                self.server_args.attention_backend
+            )
+
+        global_server_args_dict.update(
+            {
+                "decode_attention_backend": self.decode_attention_backend_str,
+                "prefill_attention_backend": self.prefill_attention_backend_str,
+            }
+        )
+        return attn_backend
+
+    def _get_attention_backend_from_str(self, backend_str: str):
+        if backend_str == "flashinfer":
             if not self.use_mla_backend:
                 from sglang.srt.layers.attention.flashinfer_backend import (
                     FlashInferAttnBackend,
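The resolution rule is simple: a phase-specific flag wins, otherwise both phases fall back to --attention-backend. A hedged restatement as a standalone function (name illustrative):

def resolve_backend_strs(attention_backend, prefill_attention_backend, decode_attention_backend):
    # Per-phase flags take priority over the shared --attention-backend value.
    decode = decode_attention_backend or attention_backend
    prefill = prefill_attention_backend or attention_backend
    return prefill, decode

# resolve_backend_strs("triton", None, "flashmla") -> ("triton", "flashmla")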
@@ -1318,7 +1367,11 @@ class ModelRunner:
 
                 # Init streams
                 if self.server_args.speculative_algorithm == "EAGLE":
-                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
+                    if (
+                        not hasattr(self, "plan_stream_for_flashinfer")
+                        or not self.plan_stream_for_flashinfer
+                    ):
+                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
                 return FlashInferAttnBackend(self)
             else:
                 from sglang.srt.layers.attention.flashinfer_mla_backend import (
@@ -1326,15 +1379,15 @@ class ModelRunner:
                 )
 
                 return FlashInferMLAAttnBackend(self)
-        elif self.server_args.attention_backend == "aiter":
+        elif backend_str == "aiter":
             from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
 
             return AiterAttnBackend(self)
-        elif self.server_args.attention_backend == "ascend":
+        elif backend_str == "ascend":
             from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
 
             return AscendAttnBackend(self)
-        elif self.server_args.attention_backend == "triton":
+        elif backend_str == "triton":
             assert not self.model_config.is_encoder_decoder, (
                 "Cross attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
@@ -1349,17 +1402,17 @@ class ModelRunner:
             from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 
             return TritonAttnBackend(self)
-        elif self.server_args.attention_backend == "torch_native":
+        elif backend_str == "torch_native":
             from sglang.srt.layers.attention.torch_native_backend import (
                 TorchNativeAttnBackend,
             )
 
             return TorchNativeAttnBackend(self)
-        elif self.server_args.attention_backend == "flashmla":
+        elif backend_str == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
 
             return FlashMLABackend(self)
-        elif self.server_args.attention_backend == "fa3":
+        elif backend_str == "fa3":
             assert (
                 torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
             ) or torch.cuda.get_device_capability()[0] == 9, (
@@ -1371,7 +1424,7 @@ class ModelRunner:
             )
 
             return FlashAttentionBackend(self)
-        elif self.server_args.attention_backend == "cutlass_mla":
+        elif backend_str == "cutlass_mla":
             from sglang.srt.layers.attention.cutlass_mla_backend import (
                 CutlassMLABackend,
             )
@@ -1385,9 +1438,7 @@ class ModelRunner:
             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
         else:
-            raise ValueError(
-                f"Invalid attention backend: {self.server_args.attention_backend}"
-            )
+            raise ValueError(f"Invalid attention backend: {backend_str}")
 
     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
@@ -1475,7 +1526,10 @@ class ModelRunner:
         if self.support_pp:
             kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
         )
 
     def forward_extend(
python/sglang/srt/models/deepseek_v2.py
@@ -925,7 +925,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.disable_chunked_prefix_cache = global_server_args_dict[
             "disable_chunked_prefix_cache"
         ]
-        self.attention_backend = global_server_args_dict["attention_backend"]
+
+        self.current_attention_backend = (
+            None  # Attention backend used by current forward batch
+        )
         self.rocm_fused_decode_mla = get_bool_env_var(
             "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
         )
@@ -1009,9 +1012,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA
 
-        if self.attention_backend == "ascend":
+        # Determine attention backend used by current forward batch
+        if forward_batch.forward_mode.is_decode_or_idle():
+            attention_backend = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend = global_server_args_dict["prefill_attention_backend"]
+        self.current_attention_backend = attention_backend
+
+        if attention_backend == "ascend":
             return AttnForwardMethod.MLA
-        elif self.attention_backend == "flashinfer":
+        elif attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
                 not self.flashinfer_mla_disable_ragged
@@ -1023,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA
             else:
                 return _dispatch_mla_subtype()
-        elif self.attention_backend == "fa3":
+        elif attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
                 sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
@@ -1040,7 +1050,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
-        elif self.attention_backend == "aiter":
+        elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
@@ -1288,9 +1298,9 @@ class DeepseekV2AttentionMLA(nn.Module):
         self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
         if (
-            self.attention_backend == "fa3"
-            or self.attention_backend == "flashinfer"
-            or self.attention_backend == "cutlass_mla"
+            self.current_attention_backend == "fa3"
+            or self.current_attention_backend == "flashinfer"
+            or self.current_attention_backend == "cutlass_mla"
         ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
python/sglang/srt/server_args.py
@@ -151,6 +151,8 @@ class ServerArgs:
 
     # Kernel backend
     attention_backend: Optional[str] = None
+    decode_attention_backend: Optional[str] = None
+    prefill_attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = None
    mm_attention_backend: Optional[str] = None
@@ -387,13 +389,19 @@ class ServerArgs:
             )
             self.page_size = 128
 
-        if self.attention_backend == "flashmla":
+        if (
+            self.attention_backend == "flashmla"
+            or self.decode_attention_backend == "flashmla"
+        ):
             logger.warning(
                 "FlashMLA only supports a page_size of 64, change page_size to 64."
             )
             self.page_size = 64
 
-        if self.attention_backend == "cutlass_mla":
+        if (
+            self.attention_backend == "cutlass_mla"
+            or self.decode_attention_backend == "cutlass_mla"
+        ):
             logger.warning(
                 "Cutlass MLA only supports a page_size of 128, change page_size to 128."
             )
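Taken together, these checks pin the page size whenever an MLA decode kernel is in play; a hedged restatement (function name illustrative, not in the diff):

def resolve_page_size(attention_backend, decode_attention_backend, page_size):
    # Mirrors the checks above: FlashMLA pins page_size to 64 and Cutlass MLA
    # to 128; the Cutlass check runs last, as in the original.
    if "flashmla" in (attention_backend, decode_attention_backend):
        page_size = 64
    if "cutlass_mla" in (attention_backend, decode_attention_backend):
        page_size = 128
    return page_size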
@@ -1213,6 +1221,35 @@ class ServerArgs:
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
+        parser.add_argument(
+            "--decode-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.decode_attention_backend,
+            help="Choose the kernels for decode attention layers (has priority over --attention-backend).",
+        )
+
+        parser.add_argument(
+            "--prefill-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.prefill_attention_backend,
+            help="Choose the kernels for prefill attention layers (has priority over --attention-backend).",
+        )
         parser.add_argument(
             "--sampling-backend",
             type=str,
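With these flags, a server can mix kernels across phases; for example (model path illustrative):

python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --prefill-attention-backend fa3 --decode-attention-backend flashinfer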
python/sglang/test/runners.py
@@ -491,6 +491,8 @@ class SRTRunner:
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
         attention_backend: Optional[str] = None,
+        prefill_attention_backend: Optional[str] = None,
+        decode_attention_backend: Optional[str] = None,
         lora_backend: str = "triton",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
@@ -540,6 +542,8 @@ class SRTRunner:
             max_loras_per_batch=max_loras_per_batch,
             lora_backend=lora_backend,
             attention_backend=attention_backend,
+            prefill_attention_backend=prefill_attention_backend,
+            decode_attention_backend=decode_attention_backend,
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
             chunked_prefill_size=chunked_prefill_size,