Revert "ROCm: Flex Attention Enablement with custom backends (#4178)" (#4186)

Author: Yineng Zhang
Date: 2025-03-07 10:27:52 -08:00
Committed by: GitHub
Parent: 0beea4503f
Commit: eb61f5c9af
7 changed files with 35 additions and 1434 deletions


@@ -79,12 +79,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback
 
-is_hip_ = is_hip()
-
-if is_hip_:
-    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-    from sglang.srt.layers.attention.aiter_decode_backend import AiterDecodeAttnBackend
-
 logger = logging.getLogger(__name__)
@@ -647,7 +641,7 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if is_hip_:  # Using natively supported format
+            if is_hip():  # Using natively supported format
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
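Note: the only change in the hunk above is replacing the cached is_hip_ flag with a direct is_hip() call; the fp8 mapping itself is unchanged. A minimal standalone sketch of that mapping (hypothetical helper name, not code from this commit):

import torch

def resolve_fp8_e5m2(on_hip: bool) -> torch.dtype:
    # ROCm/HIP natively supports the "fnuz" fp8 variant, so the KV cache
    # uses torch.float8_e5m2fnuz there; CUDA uses torch.float8_e5m2.
    return torch.float8_e5m2fnuz if on_hip else torch.float8_e5m2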
@@ -784,59 +778,33 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if is_cuda():
-            if self.server_args.attention_backend == "flashinfer":
-                # Init streams
-                if self.server_args.speculative_algorithm == "EAGLE":
-                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
-                self.attn_backend = FlashInferAttnBackend(self)
-            elif self.server_args.attention_backend == "triton":
-                assert self.sliding_window_size is None, (
-                    "Window attention is not supported in the triton attention backend. "
-                    "Please use `--attention-backend flashinfer`."
-                )
-                assert not self.model_config.is_encoder_decoder, (
-                    "Cross attention is not supported in the triton attention backend. "
-                    "Please use `--attention-backend flashinfer`."
-                )
-                if self.server_args.enable_double_sparsity:
-                    self.attn_backend = DoubleSparseAttnBackend(self)
-                else:
-                    self.attn_backend = TritonAttnBackend(self)
-            elif self.server_args.attention_backend == "torch_native":
-                self.attn_backend = TorchNativeAttnBackend(self)
-            elif self.server_args.attention_backend == "flashinfer_mla":
-                self.attn_backend = FlashInferMLAAttnBackend(self)
-            else:
-                raise ValueError(
-                    f"Invalid attention backend: {self.server_args.attention_backend}"
-                )
-        elif is_hip_:
-            # AMD hip supported attention backends
-            if self.server_args.attention_backend == "aiter":
-                self.attn_backend = AiterAttnBackend(self)
-            elif self.server_args.attention_backend == "aiter_decode":
-                self.attn_backend = AiterDecodeAttnBackend(self)
-            elif self.server_args.attention_backend == "triton":
-                assert self.sliding_window_size is None, (
-                    "Window attention is not supported in the triton attention backend. "
-                    "Please use `--attention-backend flashinfer`."
-                )
-                assert not self.model_config.is_encoder_decoder, (
-                    "Cross attention is not supported in the triton attention backend. "
-                    "Please use `--attention-backend flashinfer`."
-                )
-                if self.server_args.enable_double_sparsity:
-                    self.attn_backend = DoubleSparseAttnBackend(self)
-                else:
-                    self.attn_backend = TritonAttnBackend(self)
-            elif self.server_args.attention_backend == "torch_native":
-                self.attn_backend = TorchNativeAttnBackend(self)
-            else:
-                raise ValueError(
-                    f"Invalid attention backend: {self.server_args.attention_backend}"
-                )
+        if self.server_args.attention_backend == "flashinfer":
+            # Init streams
+            if self.server_args.speculative_algorithm == "EAGLE":
+                self.plan_stream_for_flashinfer = torch.cuda.Stream()
+            self.attn_backend = FlashInferAttnBackend(self)
+        elif self.server_args.attention_backend == "triton":
+            assert self.sliding_window_size is None, (
+                "Window attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
+            )
+            assert not self.model_config.is_encoder_decoder, (
+                "Cross attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
+            )
+            if self.server_args.enable_double_sparsity:
+                self.attn_backend = DoubleSparseAttnBackend(self)
+            else:
+                self.attn_backend = TritonAttnBackend(self)
+        elif self.server_args.attention_backend == "torch_native":
+            self.attn_backend = TorchNativeAttnBackend(self)
+        elif self.server_args.attention_backend == "flashinfer_mla":
+            self.attn_backend = FlashInferMLAAttnBackend(self)
+        else:
+            raise ValueError(
+                f"Invalid attention backend: {self.server_args.attention_backend}"
+            )
 
     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"