AITER backend extension and workload optimizations (#6838)

Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
This commit is contained in:
HAI
2025-06-05 23:00:18 -07:00
committed by GitHub
parent 562f279a2d
commit b819381fec
12 changed files with 583 additions and 164 deletions

View File

@@ -105,6 +105,7 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _is_cuda:
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
@@ -120,6 +121,9 @@ if _is_hip:
decode_attention_fwd_grouped_rope,
)
if _use_aiter:
from aiter.rotary_embedding import get_rope
logger = logging.getLogger(__name__)
@@ -697,6 +701,7 @@ class DeepseekV2AttentionMLA(nn.Module):
)
self.alt_stream = alt_stream
self.attn_mha.kv_b_proj = None
self.w_kc = None
self.w_vc = None
@@ -766,6 +771,15 @@ class DeepseekV2AttentionMLA(nn.Module):
return AttnForwardMethod.MHA_CHUNKED_KV
else:
return _dispatch_mla_subtype()
elif self.attention_backend == "aiter":
if (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
return AttnForwardMethod.MHA
else:
return AttnForwardMethod.MLA
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
if (
@@ -813,6 +827,9 @@ class DeepseekV2AttentionMLA(nn.Module):
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
):
if self.attn_mha.kv_b_proj is None:
self.attn_mha.kv_b_proj = self.kv_b_proj
if hidden_states.shape[0] == 0:
assert (
not self.o_proj.reduce_results