[feat] Support different attention backends for prefill and decode (#6338)
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com> Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
@@ -188,6 +188,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| Arguments | Description | Defaults |
|
| Arguments | Description | Defaults |
|
||||||
|-----------|-------------|----------|
|
|-----------|-------------|----------|
|
||||||
| `--attention-backend` | Choose the kernels for attention layers. | None |
|
| `--attention-backend` | Choose the kernels for attention layers. | None |
|
||||||
|
| `decode_attention_backend` | (Experimental) This argument specifies the backend for decode attention computation. Note that this argument has priority over `attention_backend`. | None |
|
||||||
|
| `prefill_attention_backend` | (Experimental) This argument specifies the backend for prefill attention computation. Note that this argument has priority over `attention_backend`. | None |
|
||||||
| `--sampling-backend` | Choose the kernels for sampling layers. | None |
|
| `--sampling-backend` | Choose the kernels for sampling layers. | None |
|
||||||
| `--grammar-backend` | Choose the backend for grammar-guided decoding. | None |
|
| `--grammar-backend` | Choose the backend for grammar-guided decoding. | None |
|
||||||
| `--mm-attention-backend` | Set multimodal attention backend. | None |
|
| `--mm-attention-backend` | Set multimodal attention backend. | None |
|
||||||
|
|||||||
100
python/sglang/srt/layers/attention/hybrid_attn_backend.py
Normal file
100
python/sglang/srt/layers/attention/hybrid_attn_backend.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
|
|
||||||
|
|
||||||
|
class HybridAttnBackend(AttentionBackend):
|
||||||
|
"""Support different backends for prefill and decode."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
|
||||||
|
):
|
||||||
|
self.prefill_backend = prefill_backend
|
||||||
|
self.decode_backend = decode_backend
|
||||||
|
|
||||||
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
if forward_batch.forward_mode.is_decode():
|
||||||
|
self.decode_backend.init_forward_metadata(forward_batch)
|
||||||
|
else:
|
||||||
|
self.prefill_backend.init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
|
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||||
|
self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
|
||||||
|
|
||||||
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
|
self,
|
||||||
|
bs: int,
|
||||||
|
num_tokens: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
|
):
|
||||||
|
self.decode_backend.init_forward_metadata_capture_cuda_graph(
|
||||||
|
bs,
|
||||||
|
num_tokens,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
encoder_lens,
|
||||||
|
forward_mode,
|
||||||
|
spec_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
|
self,
|
||||||
|
bs: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_sum: int,
|
||||||
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
|
):
|
||||||
|
self.decode_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
|
bs,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
seq_lens_sum,
|
||||||
|
encoder_lens,
|
||||||
|
forward_mode,
|
||||||
|
spec_info,
|
||||||
|
seq_lens_cpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
return self.decode_backend.get_cuda_graph_seq_len_fill_value()
|
||||||
|
|
||||||
|
def forward_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
return self.decode_backend.forward_decode(
|
||||||
|
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_extend(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
return self.prefill_backend.forward_extend(
|
||||||
|
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||||
|
)
|
||||||
@@ -1690,16 +1690,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
extend_prefix_lens = self.prefix_lens
|
extend_prefix_lens = self.prefix_lens
|
||||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||||
|
|
||||||
|
if self.forward_mode.is_decode_or_idle():
|
||||||
|
attention_backend_str = global_server_args_dict["decode_attention_backend"]
|
||||||
|
else:
|
||||||
|
attention_backend_str = global_server_args_dict["prefill_attention_backend"]
|
||||||
# Create seq_lens_cpu when needed
|
# Create seq_lens_cpu when needed
|
||||||
if (
|
if (
|
||||||
global_server_args_dict["attention_backend"] == "fa3"
|
attention_backend_str == "fa3"
|
||||||
or (
|
or (
|
||||||
global_server_args_dict["use_mla_backend"]
|
global_server_args_dict["use_mla_backend"]
|
||||||
and global_server_args_dict["attention_backend"] == "flashinfer"
|
and attention_backend_str == "flashinfer"
|
||||||
)
|
)
|
||||||
or global_server_args_dict["attention_backend"] == "flashmla"
|
or attention_backend_str == "flashmla"
|
||||||
or global_server_args_dict["attention_backend"] == "cutlass_mla"
|
or attention_backend_str == "cutlass_mla"
|
||||||
or global_server_args_dict["attention_backend"] == "ascend"
|
or attention_backend_str == "ascend"
|
||||||
or global_server_args_dict["enable_two_batch_overlap"]
|
or global_server_args_dict["enable_two_batch_overlap"]
|
||||||
):
|
):
|
||||||
seq_lens_cpu = (
|
seq_lens_cpu = (
|
||||||
|
|||||||
@@ -1308,9 +1308,58 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
self.attn_backend = self._get_attention_backend()
|
self.attn_backend = self._get_attention_backend()
|
||||||
|
|
||||||
# TODO unify with 6338
|
|
||||||
def _get_attention_backend(self):
|
def _get_attention_backend(self):
|
||||||
if self.server_args.attention_backend == "flashinfer":
|
"""Init attention kernel backend."""
|
||||||
|
self.decode_attention_backend_str = (
|
||||||
|
self.server_args.decode_attention_backend
|
||||||
|
if self.server_args.decode_attention_backend
|
||||||
|
else self.server_args.attention_backend
|
||||||
|
)
|
||||||
|
self.prefill_attention_backend_str = (
|
||||||
|
self.server_args.prefill_attention_backend
|
||||||
|
if self.server_args.prefill_attention_backend
|
||||||
|
else self.server_args.attention_backend
|
||||||
|
)
|
||||||
|
if self.decode_attention_backend_str != self.prefill_attention_backend_str:
|
||||||
|
assert (
|
||||||
|
self.server_args.speculative_algorithm is None
|
||||||
|
), "Currently HybridAttentionBackend does not support speculative decoding."
|
||||||
|
from sglang.srt.layers.attention.hybrid_attn_backend import (
|
||||||
|
HybridAttnBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_backend = HybridAttnBackend(
|
||||||
|
decode_backend=self._get_attention_backend_from_str(
|
||||||
|
self.decode_attention_backend_str
|
||||||
|
),
|
||||||
|
prefill_backend=self._get_attention_backend_from_str(
|
||||||
|
self.prefill_attention_backend_str
|
||||||
|
),
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Using hybrid attention backend for decode and prefill: "
|
||||||
|
f"decode_backend={self.decode_attention_backend_str}, "
|
||||||
|
f"prefill_backend={self.prefill_attention_backend_str}."
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
|
||||||
|
f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_backend = self._get_attention_backend_from_str(
|
||||||
|
self.server_args.attention_backend
|
||||||
|
)
|
||||||
|
|
||||||
|
global_server_args_dict.update(
|
||||||
|
{
|
||||||
|
"decode_attention_backend": self.decode_attention_backend_str,
|
||||||
|
"prefill_attention_backend": self.prefill_attention_backend_str,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return attn_backend
|
||||||
|
|
||||||
|
def _get_attention_backend_from_str(self, backend_str: str):
|
||||||
|
if backend_str == "flashinfer":
|
||||||
if not self.use_mla_backend:
|
if not self.use_mla_backend:
|
||||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||||
FlashInferAttnBackend,
|
FlashInferAttnBackend,
|
||||||
@@ -1318,6 +1367,10 @@ class ModelRunner:
|
|||||||
|
|
||||||
# Init streams
|
# Init streams
|
||||||
if self.server_args.speculative_algorithm == "EAGLE":
|
if self.server_args.speculative_algorithm == "EAGLE":
|
||||||
|
if (
|
||||||
|
not hasattr(self, "plan_stream_for_flashinfer")
|
||||||
|
or not self.plan_stream_for_flashinfer
|
||||||
|
):
|
||||||
self.plan_stream_for_flashinfer = torch.cuda.Stream()
|
self.plan_stream_for_flashinfer = torch.cuda.Stream()
|
||||||
return FlashInferAttnBackend(self)
|
return FlashInferAttnBackend(self)
|
||||||
else:
|
else:
|
||||||
@@ -1326,15 +1379,15 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return FlashInferMLAAttnBackend(self)
|
return FlashInferMLAAttnBackend(self)
|
||||||
elif self.server_args.attention_backend == "aiter":
|
elif backend_str == "aiter":
|
||||||
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
|
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
|
||||||
|
|
||||||
return AiterAttnBackend(self)
|
return AiterAttnBackend(self)
|
||||||
elif self.server_args.attention_backend == "ascend":
|
elif backend_str == "ascend":
|
||||||
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
|
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
|
||||||
|
|
||||||
return AscendAttnBackend(self)
|
return AscendAttnBackend(self)
|
||||||
elif self.server_args.attention_backend == "triton":
|
elif backend_str == "triton":
|
||||||
assert not self.model_config.is_encoder_decoder, (
|
assert not self.model_config.is_encoder_decoder, (
|
||||||
"Cross attention is not supported in the triton attention backend. "
|
"Cross attention is not supported in the triton attention backend. "
|
||||||
"Please use `--attention-backend flashinfer`."
|
"Please use `--attention-backend flashinfer`."
|
||||||
@@ -1349,17 +1402,17 @@ class ModelRunner:
|
|||||||
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||||
|
|
||||||
return TritonAttnBackend(self)
|
return TritonAttnBackend(self)
|
||||||
elif self.server_args.attention_backend == "torch_native":
|
elif backend_str == "torch_native":
|
||||||
from sglang.srt.layers.attention.torch_native_backend import (
|
from sglang.srt.layers.attention.torch_native_backend import (
|
||||||
TorchNativeAttnBackend,
|
TorchNativeAttnBackend,
|
||||||
)
|
)
|
||||||
|
|
||||||
return TorchNativeAttnBackend(self)
|
return TorchNativeAttnBackend(self)
|
||||||
elif self.server_args.attention_backend == "flashmla":
|
elif backend_str == "flashmla":
|
||||||
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
|
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
|
||||||
|
|
||||||
return FlashMLABackend(self)
|
return FlashMLABackend(self)
|
||||||
elif self.server_args.attention_backend == "fa3":
|
elif backend_str == "fa3":
|
||||||
assert (
|
assert (
|
||||||
torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
|
torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
|
||||||
) or torch.cuda.get_device_capability()[0] == 9, (
|
) or torch.cuda.get_device_capability()[0] == 9, (
|
||||||
@@ -1371,7 +1424,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return FlashAttentionBackend(self)
|
return FlashAttentionBackend(self)
|
||||||
elif self.server_args.attention_backend == "cutlass_mla":
|
elif backend_str == "cutlass_mla":
|
||||||
from sglang.srt.layers.attention.cutlass_mla_backend import (
|
from sglang.srt.layers.attention.cutlass_mla_backend import (
|
||||||
CutlassMLABackend,
|
CutlassMLABackend,
|
||||||
)
|
)
|
||||||
@@ -1385,9 +1438,7 @@ class ModelRunner:
|
|||||||
logger.info(f"Intel AMX attention backend is enabled.")
|
logger.info(f"Intel AMX attention backend is enabled.")
|
||||||
return IntelAMXAttnBackend(self)
|
return IntelAMXAttnBackend(self)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"Invalid attention backend: {backend_str}")
|
||||||
f"Invalid attention backend: {self.server_args.attention_backend}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_double_sparsity_channel_config(self, selected_channel):
|
def init_double_sparsity_channel_config(self, selected_channel):
|
||||||
selected_channel = "." + selected_channel + "_proj"
|
selected_channel = "." + selected_channel + "_proj"
|
||||||
@@ -1475,7 +1526,10 @@ class ModelRunner:
|
|||||||
if self.support_pp:
|
if self.support_pp:
|
||||||
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
|
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
|
forward_batch.input_ids,
|
||||||
|
forward_batch.positions,
|
||||||
|
forward_batch,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
|
|||||||
@@ -925,7 +925,10 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
self.disable_chunked_prefix_cache = global_server_args_dict[
|
self.disable_chunked_prefix_cache = global_server_args_dict[
|
||||||
"disable_chunked_prefix_cache"
|
"disable_chunked_prefix_cache"
|
||||||
]
|
]
|
||||||
self.attention_backend = global_server_args_dict["attention_backend"]
|
|
||||||
|
self.current_attention_backend = (
|
||||||
|
None # Attention backend used by current forward batch
|
||||||
|
)
|
||||||
self.rocm_fused_decode_mla = get_bool_env_var(
|
self.rocm_fused_decode_mla = get_bool_env_var(
|
||||||
"SGLANG_ROCM_FUSED_DECODE_MLA", "false"
|
"SGLANG_ROCM_FUSED_DECODE_MLA", "false"
|
||||||
)
|
)
|
||||||
@@ -1009,9 +1012,16 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return AttnForwardMethod.MLA
|
return AttnForwardMethod.MLA
|
||||||
|
|
||||||
if self.attention_backend == "ascend":
|
# Determine attention backend used by current forward batch
|
||||||
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
|
attention_backend = global_server_args_dict["decode_attention_backend"]
|
||||||
|
else:
|
||||||
|
attention_backend = global_server_args_dict["prefill_attention_backend"]
|
||||||
|
self.current_attention_backend = attention_backend
|
||||||
|
|
||||||
|
if attention_backend == "ascend":
|
||||||
return AttnForwardMethod.MLA
|
return AttnForwardMethod.MLA
|
||||||
elif self.attention_backend == "flashinfer":
|
elif attention_backend == "flashinfer":
|
||||||
# Flashinfer MLA: Do not absorb when enabling ragged prefill
|
# Flashinfer MLA: Do not absorb when enabling ragged prefill
|
||||||
if (
|
if (
|
||||||
not self.flashinfer_mla_disable_ragged
|
not self.flashinfer_mla_disable_ragged
|
||||||
@@ -1023,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
return AttnForwardMethod.MHA
|
return AttnForwardMethod.MHA
|
||||||
else:
|
else:
|
||||||
return _dispatch_mla_subtype()
|
return _dispatch_mla_subtype()
|
||||||
elif self.attention_backend == "fa3":
|
elif attention_backend == "fa3":
|
||||||
# Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
|
# Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
|
||||||
if forward_batch.extend_prefix_lens_cpu is not None:
|
if forward_batch.extend_prefix_lens_cpu is not None:
|
||||||
sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
|
sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
|
||||||
@@ -1040,7 +1050,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
else:
|
else:
|
||||||
return _dispatch_mla_subtype()
|
return _dispatch_mla_subtype()
|
||||||
elif self.attention_backend == "aiter":
|
elif attention_backend == "aiter":
|
||||||
if (
|
if (
|
||||||
forward_batch.forward_mode.is_extend()
|
forward_batch.forward_mode.is_extend()
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
@@ -1288,9 +1298,9 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
self.attention_backend == "fa3"
|
self.current_attention_backend == "fa3"
|
||||||
or self.attention_backend == "flashinfer"
|
or self.current_attention_backend == "flashinfer"
|
||||||
or self.attention_backend == "cutlass_mla"
|
or self.current_attention_backend == "cutlass_mla"
|
||||||
):
|
):
|
||||||
attn_output = self.attn_mqa(
|
attn_output = self.attn_mqa(
|
||||||
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
||||||
|
|||||||
@@ -151,6 +151,8 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Kernel backend
|
# Kernel backend
|
||||||
attention_backend: Optional[str] = None
|
attention_backend: Optional[str] = None
|
||||||
|
decode_attention_backend: Optional[str] = None
|
||||||
|
prefill_attention_backend: Optional[str] = None
|
||||||
sampling_backend: Optional[str] = None
|
sampling_backend: Optional[str] = None
|
||||||
grammar_backend: Optional[str] = None
|
grammar_backend: Optional[str] = None
|
||||||
mm_attention_backend: Optional[str] = None
|
mm_attention_backend: Optional[str] = None
|
||||||
@@ -387,13 +389,19 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
self.page_size = 128
|
self.page_size = 128
|
||||||
|
|
||||||
if self.attention_backend == "flashmla":
|
if (
|
||||||
|
self.attention_backend == "flashmla"
|
||||||
|
or self.decode_attention_backend == "flashmla"
|
||||||
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"FlashMLA only supports a page_size of 64, change page_size to 64."
|
"FlashMLA only supports a page_size of 64, change page_size to 64."
|
||||||
)
|
)
|
||||||
self.page_size = 64
|
self.page_size = 64
|
||||||
|
|
||||||
if self.attention_backend == "cutlass_mla":
|
if (
|
||||||
|
self.attention_backend == "cutlass_mla"
|
||||||
|
or self.decode_attention_backend == "cutlass_mla"
|
||||||
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Cutlass MLA only supports a page_size of 128, change page_size to 128."
|
"Cutlass MLA only supports a page_size of 128, change page_size to 128."
|
||||||
)
|
)
|
||||||
@@ -1213,6 +1221,35 @@ class ServerArgs:
|
|||||||
default=ServerArgs.attention_backend,
|
default=ServerArgs.attention_backend,
|
||||||
help="Choose the kernels for attention layers.",
|
help="Choose the kernels for attention layers.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-attention-backend",
|
||||||
|
type=str,
|
||||||
|
choices=[
|
||||||
|
"flashinfer",
|
||||||
|
"triton",
|
||||||
|
"torch_native",
|
||||||
|
"fa3",
|
||||||
|
"flashmla",
|
||||||
|
"cutlass_mla",
|
||||||
|
],
|
||||||
|
default=ServerArgs.decode_attention_backend,
|
||||||
|
help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefill-attention-backend",
|
||||||
|
type=str,
|
||||||
|
choices=[
|
||||||
|
"flashinfer",
|
||||||
|
"triton",
|
||||||
|
"torch_native",
|
||||||
|
"fa3",
|
||||||
|
"flashmla",
|
||||||
|
"cutlass_mla",
|
||||||
|
],
|
||||||
|
default=ServerArgs.prefill_attention_backend,
|
||||||
|
help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sampling-backend",
|
"--sampling-backend",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -491,6 +491,8 @@ class SRTRunner:
|
|||||||
lora_paths: List[str] = None,
|
lora_paths: List[str] = None,
|
||||||
max_loras_per_batch: int = 4,
|
max_loras_per_batch: int = 4,
|
||||||
attention_backend: Optional[str] = None,
|
attention_backend: Optional[str] = None,
|
||||||
|
prefill_attention_backend: Optional[str] = None,
|
||||||
|
decode_attention_backend: Optional[str] = None,
|
||||||
lora_backend: str = "triton",
|
lora_backend: str = "triton",
|
||||||
disable_cuda_graph: bool = False,
|
disable_cuda_graph: bool = False,
|
||||||
disable_radix_cache: bool = False,
|
disable_radix_cache: bool = False,
|
||||||
@@ -540,6 +542,8 @@ class SRTRunner:
|
|||||||
max_loras_per_batch=max_loras_per_batch,
|
max_loras_per_batch=max_loras_per_batch,
|
||||||
lora_backend=lora_backend,
|
lora_backend=lora_backend,
|
||||||
attention_backend=attention_backend,
|
attention_backend=attention_backend,
|
||||||
|
prefill_attention_backend=prefill_attention_backend,
|
||||||
|
decode_attention_backend=decode_attention_backend,
|
||||||
disable_cuda_graph=disable_cuda_graph,
|
disable_cuda_graph=disable_cuda_graph,
|
||||||
disable_radix_cache=disable_radix_cache,
|
disable_radix_cache=disable_radix_cache,
|
||||||
chunked_prefill_size=chunked_prefill_size,
|
chunked_prefill_size=chunked_prefill_size,
|
||||||
|
|||||||
@@ -109,6 +109,7 @@ suites = {
|
|||||||
TestFile("test_vision_openai_server_b.py", 620),
|
TestFile("test_vision_openai_server_b.py", 620),
|
||||||
TestFile("test_w8a8_quantization.py", 46),
|
TestFile("test_w8a8_quantization.py", 46),
|
||||||
TestFile("test_reasoning_parser.py", 5),
|
TestFile("test_reasoning_parser.py", 5),
|
||||||
|
TestFile("test_hybrid_attn_backend.py", 100),
|
||||||
],
|
],
|
||||||
"per-commit-amd": [
|
"per-commit-amd": [
|
||||||
TestFile("models/lora/test_lora_backend.py", 99),
|
TestFile("models/lora/test_lora_backend.py", 99),
|
||||||
|
|||||||
109
test/srt/test_hybrid_attn_backend.py
Normal file
109
test/srt/test_hybrid_attn_backend.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import get_device_sm, kill_process_tree
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
GSM_DATASET_PATH = None
|
||||||
|
|
||||||
|
# Default server arguments shared across all tests
|
||||||
|
DEFAULT_SERVER_ARGS = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--cuda-graph-max-bs",
|
||||||
|
"8",
|
||||||
|
"--prefill-attention-backend",
|
||||||
|
"fa3",
|
||||||
|
"--decode-attention-backend",
|
||||||
|
"flashinfer",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
|
||||||
|
class TestHybridAttnBackendBase(CustomTestCase):
|
||||||
|
|
||||||
|
model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
accuracy_threshold = 0.65 # derived tests need to override this
|
||||||
|
speculative_decode = False
|
||||||
|
spec_decode_threshold = 1.0 # derived spec decoding tests need to override this
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
"""Return the arguments for the server launch. Override in subclasses."""
|
||||||
|
return DEFAULT_SERVER_ARGS
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
# disable deep gemm precompile to make launch server faster
|
||||||
|
# please don't do this if you want to make your inference workload faster
|
||||||
|
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
|
||||||
|
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=cls.get_server_args(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
requests.get(self.base_url + "/flush_cache")
|
||||||
|
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=4,
|
||||||
|
num_questions=100,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
data_path=GSM_DATASET_PATH,
|
||||||
|
)
|
||||||
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
|
||||||
|
# Use the appropriate metric key based on the test class
|
||||||
|
metric_key = "accuracy"
|
||||||
|
self.assertGreater(metrics[metric_key], self.accuracy_threshold)
|
||||||
|
|
||||||
|
if self.speculative_decode:
|
||||||
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
|
avg_spec_accept_length = server_info.json()["internal_states"][0][
|
||||||
|
"avg_spec_accept_length"
|
||||||
|
]
|
||||||
|
print(f"{avg_spec_accept_length=}")
|
||||||
|
self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHybridAttnBackendMLA(TestHybridAttnBackendBase):
|
||||||
|
accuracy_threshold = 0.60
|
||||||
|
model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
return DEFAULT_SERVER_ARGS
|
||||||
|
|
||||||
|
|
||||||
|
class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase):
|
||||||
|
accuracy_threshold = 0.65
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user