AITER backend extension and workload optimizations (#6838)
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
This commit is contained in:
@@ -105,6 +105,7 @@ from sglang.srt.utils import (
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
||||
@@ -120,6 +121,9 @@ if _is_hip:
|
||||
decode_attention_fwd_grouped_rope,
|
||||
)
|
||||
|
||||
if _use_aiter:
|
||||
from aiter.rotary_embedding import get_rope
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -697,6 +701,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
)
|
||||
|
||||
self.alt_stream = alt_stream
|
||||
self.attn_mha.kv_b_proj = None
|
||||
|
||||
self.w_kc = None
|
||||
self.w_vc = None
|
||||
@@ -766,6 +771,15 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||
else:
|
||||
return _dispatch_mla_subtype()
|
||||
elif self.attention_backend == "aiter":
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
):
|
||||
return AttnForwardMethod.MHA
|
||||
else:
|
||||
return AttnForwardMethod.MLA
|
||||
else:
|
||||
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
||||
if (
|
||||
@@ -813,6 +827,9 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
zero_allocator: BumpAllocator,
|
||||
):
|
||||
if self.attn_mha.kv_b_proj is None:
|
||||
self.attn_mha.kv_b_proj = self.kv_b_proj
|
||||
|
||||
if hidden_states.shape[0] == 0:
|
||||
assert (
|
||||
not self.o_proj.reduce_results
|
||||
|
||||
Reference in New Issue
Block a user