ROCm: Flex Attention Enablement with custom backends (#4178)

Co-authored-by: linsun12 <linsun12@amd.com>
This commit is contained in:
HAI
2025-03-07 04:38:53 -08:00
committed by GitHub
parent c827c671f7
commit 0beea4503f
7 changed files with 1434 additions and 35 deletions

View File

@@ -79,6 +79,12 @@ from sglang.srt.utils import (
)
from sglang.utils import get_exception_traceback
is_hip_ = is_hip()
if is_hip_:
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
from sglang.srt.layers.attention.aiter_decode_backend import AiterDecodeAttnBackend
logger = logging.getLogger(__name__)
@@ -641,7 +647,7 @@ class ModelRunner:
if self.server_args.kv_cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
if is_hip(): # Using natively supported format
if is_hip_: # Using natively supported format
self.kv_cache_dtype = torch.float8_e5m2fnuz
else:
self.kv_cache_dtype = torch.float8_e5m2
@@ -778,33 +784,59 @@ class ModelRunner:
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.attention_backend == "flashinfer":
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
if is_cuda():
if self.server_args.attention_backend == "flashinfer":
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
self.attn_backend = FlashInferAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
assert not self.model_config.is_encoder_decoder, (
"Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
if self.server_args.enable_double_sparsity:
self.attn_backend = DoubleSparseAttnBackend(self)
self.attn_backend = FlashInferAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
assert not self.model_config.is_encoder_decoder, (
"Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
if self.server_args.enable_double_sparsity:
self.attn_backend = DoubleSparseAttnBackend(self)
else:
self.attn_backend = TritonAttnBackend(self)
elif self.server_args.attention_backend == "torch_native":
self.attn_backend = TorchNativeAttnBackend(self)
elif self.server_args.attention_backend == "flashinfer_mla":
self.attn_backend = FlashInferMLAAttnBackend(self)
else:
self.attn_backend = TritonAttnBackend(self)
elif self.server_args.attention_backend == "torch_native":
self.attn_backend = TorchNativeAttnBackend(self)
elif self.server_args.attention_backend == "flashinfer_mla":
self.attn_backend = FlashInferMLAAttnBackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
)
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
)
elif is_hip_:
# AMD hip supported attention backends
if self.server_args.attention_backend == "aiter":
self.attn_backend = AiterAttnBackend(self)
elif self.server_args.attention_backend == "aiter_decode":
self.attn_backend = AiterDecodeAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
assert not self.model_config.is_encoder_decoder, (
"Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
if self.server_args.enable_double_sparsity:
self.attn_backend = DoubleSparseAttnBackend(self)
else:
self.attn_backend = TritonAttnBackend(self)
elif self.server_args.attention_backend == "torch_native":
self.attn_backend = TorchNativeAttnBackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
)
def init_double_sparsity_channel_config(self, selected_channel):
selected_channel = "." + selected_channel + "_proj"