ROCm: Flex Attention Enablement with custom backends (#4178)
Co-authored-by: linsun12 <linsun12@amd.com>
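This change splits attention-backend selection in ModelRunner.init_attention_backend by GPU platform: the existing flashinfer, triton, torch_native, and flashinfer_mla paths move under an is_cuda() branch, while a new is_hip_ branch registers the ROCm backends AiterAttnBackend ("aiter") and AiterDecodeAttnBackend ("aiter_decode") next to the triton and torch_native fallbacks. The is_hip() probe is also cached in a module-level is_hip_ flag and reused for the fp8 KV-cache dtype choice.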
@@ -79,6 +79,12 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback

+is_hip_ = is_hip()
+
+if is_hip_:
+    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+    from sglang.srt.layers.attention.aiter_decode_backend import AiterDecodeAttnBackend
+
 logger = logging.getLogger(__name__)


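Two things happen in this hunk: is_hip() is evaluated once and cached in the module-level is_hip_, and the AITER imports sit behind that flag so CUDA installs, which do not ship the ROCm aiter kernels, never attempt them. A standalone sketch of the pattern, assuming torch.version.hip is roughly what is_hip() checks (the IS_ROCM name is ours):

    import torch

    # Evaluate the platform probe once at import time and reuse the result.
    IS_ROCM: bool = torch.version.hip is not None

    if IS_ROCM:
        # ROCm-only modules stay behind the guard: a CUDA-only install
        # never reaches this import, so the absent aiter package cannot
        # break startup.
        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend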
@@ -641,7 +647,7 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if is_hip():  # Using natively supported format
+            if is_hip_:  # Using natively supported format
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
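The branch exists because the two vendors implement different e5m2 float8 encodings: CUDA devices use the standard torch.float8_e5m2, while the format ROCm hardware supports natively is torch.float8_e5m2fnuz (finite-only, no negative zero). A minimal sketch of the same choice in isolation (helper name hypothetical; needs a PyTorch build that defines both dtypes):

    import torch

    def fp8_e5m2_kv_dtype(on_hip: bool) -> torch.dtype:
        # ROCm natively uses the "fnuz" e5m2 variant; CUDA uses plain e5m2.
        return torch.float8_e5m2fnuz if on_hip else torch.float8_e5m2

    print(fp8_e5m2_kv_dtype(on_hip=True))   # torch.float8_e5m2fnuz
    print(fp8_e5m2_kv_dtype(on_hip=False))  # torch.float8_e5m2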
@@ -778,33 +784,59 @@ class ModelRunner:

     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.attention_backend == "flashinfer":
-            # Init streams
-            if self.server_args.speculative_algorithm == "EAGLE":
-                self.plan_stream_for_flashinfer = torch.cuda.Stream()
+        if is_cuda():
+            if self.server_args.attention_backend == "flashinfer":
+                # Init streams
+                if self.server_args.speculative_algorithm == "EAGLE":
+                    self.plan_stream_for_flashinfer = torch.cuda.Stream()

-            self.attn_backend = FlashInferAttnBackend(self)
-        elif self.server_args.attention_backend == "triton":
-            assert self.sliding_window_size is None, (
-                "Window attention is not supported in the triton attention backend. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            assert not self.model_config.is_encoder_decoder, (
-                "Cross attention is not supported in the triton attention backend. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            if self.server_args.enable_double_sparsity:
-                self.attn_backend = DoubleSparseAttnBackend(self)
-            else:
-                self.attn_backend = TritonAttnBackend(self)
-        elif self.server_args.attention_backend == "torch_native":
-            self.attn_backend = TorchNativeAttnBackend(self)
-        elif self.server_args.attention_backend == "flashinfer_mla":
-            self.attn_backend = FlashInferMLAAttnBackend(self)
-        else:
-            raise ValueError(
-                f"Invalid attention backend: {self.server_args.attention_backend}"
-            )
+                self.attn_backend = FlashInferAttnBackend(self)
+            elif self.server_args.attention_backend == "triton":
+                assert self.sliding_window_size is None, (
+                    "Window attention is not supported in the triton attention backend. "
+                    "Please use `--attention-backend flashinfer`."
+                )
+                assert not self.model_config.is_encoder_decoder, (
+                    "Cross attention is not supported in the triton attention backend. "
+                    "Please use `--attention-backend flashinfer`."
+                )
+                if self.server_args.enable_double_sparsity:
+                    self.attn_backend = DoubleSparseAttnBackend(self)
+                else:
+                    self.attn_backend = TritonAttnBackend(self)
+            elif self.server_args.attention_backend == "torch_native":
+                self.attn_backend = TorchNativeAttnBackend(self)
+            elif self.server_args.attention_backend == "flashinfer_mla":
+                self.attn_backend = FlashInferMLAAttnBackend(self)
+            else:
+                raise ValueError(
+                    f"Invalid attention backend: {self.server_args.attention_backend}"
+                )
+        elif is_hip_:
+            # AMD hip supported attention backends
+            if self.server_args.attention_backend == "aiter":
+                self.attn_backend = AiterAttnBackend(self)
+            elif self.server_args.attention_backend == "aiter_decode":
+                self.attn_backend = AiterDecodeAttnBackend(self)
+            elif self.server_args.attention_backend == "triton":
+                assert self.sliding_window_size is None, (
+                    "Window attention is not supported in the triton attention backend. "
+                    "Please use `--attention-backend flashinfer`."
+                )
+                assert not self.model_config.is_encoder_decoder, (
+                    "Cross attention is not supported in the triton attention backend. "
+                    "Please use `--attention-backend flashinfer`."
+                )
+                if self.server_args.enable_double_sparsity:
+                    self.attn_backend = DoubleSparseAttnBackend(self)
+                else:
+                    self.attn_backend = TritonAttnBackend(self)
+            elif self.server_args.attention_backend == "torch_native":
+                self.attn_backend = TorchNativeAttnBackend(self)
+            else:
+                raise ValueError(
+                    f"Invalid attention backend: {self.server_args.attention_backend}"
+                )

     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
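Reduced to its shape, the new method is a per-platform registry of accepted backend names with a shared error path. The names below follow the diff; the table form and the validate_backend helper are illustrative only:

    # Backend names accepted on each platform, mirroring the branches above.
    CUDA_BACKENDS = {"flashinfer", "triton", "torch_native", "flashinfer_mla"}
    HIP_BACKENDS = {"aiter", "aiter_decode", "triton", "torch_native"}

    def validate_backend(name: str, on_hip: bool) -> str:
        supported = HIP_BACKENDS if on_hip else CUDA_BACKENDS
        if name not in supported:
            raise ValueError(f"Invalid attention backend: {name}")
        return name

    validate_backend("aiter", on_hip=True)  # accepted on ROCm
    try:
        validate_backend("aiter", on_hip=False)  # aiter is ROCm-only here
    except ValueError as err:
        print(err)

On a ROCm build the new kernels are therefore opted into at launch via `--attention-backend aiter`, or `--attention-backend aiter_decode` for the variant that, judging by the class name, applies AITER to the decode path only; CUDA selection is unchanged.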