Fix accuracy drop of dsv3 run in dp enablement (#8677)
Co-authored-by: wunhuang <wunhuang@amd.com>
This commit is contained in:
@@ -18,7 +18,10 @@ import triton.language as tl
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
get_attention_tp_size,
|
||||||
|
is_dp_attention_enabled,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -154,6 +157,8 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.enable_dp_attention = is_dp_attention_enabled()
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Init auxiliary variables for triton attention backend."""
|
"""Init auxiliary variables for triton attention backend."""
|
||||||
|
|
||||||
@@ -302,19 +307,19 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
self.mla_indices_updater_prefill.update(
|
self.mla_indices_updater_prefill.update(
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.extend_prefix_lens,
|
forward_batch.seq_lens,
|
||||||
sum(forward_batch.extend_prefix_lens_cpu),
|
forward_batch.seq_lens_sum,
|
||||||
forward_batch.extend_seq_lens,
|
forward_batch.extend_seq_lens,
|
||||||
max(forward_batch.extend_seq_lens_cpu),
|
forward_batch.extend_seq_lens.max().item(),
|
||||||
forward_batch.seq_lens_cpu.max().item(),
|
forward_batch.seq_lens.max().item(),
|
||||||
spec_info=None,
|
spec_info=None,
|
||||||
)
|
)
|
||||||
self.mla_indices_updater_prefill.kv_indptr += (
|
|
||||||
self.mla_indices_updater_prefill.qo_indptr
|
kv_indices = self.mla_indices_updater_prefill.kv_indices
|
||||||
)
|
|
||||||
self.forward_metadata = ForwardMetadata(
|
self.forward_metadata = ForwardMetadata(
|
||||||
self.mla_indices_updater_prefill.kv_indptr,
|
self.mla_indices_updater_prefill.kv_indptr,
|
||||||
self.mla_indices_updater_prefill.kv_indices,
|
kv_indices,
|
||||||
self.mla_indices_updater_prefill.qo_indptr,
|
self.mla_indices_updater_prefill.qo_indptr,
|
||||||
self.kv_last_page_len[:bs],
|
self.kv_last_page_len[:bs],
|
||||||
self.mla_indices_updater_prefill.max_q_len,
|
self.mla_indices_updater_prefill.max_q_len,
|
||||||
@@ -614,66 +619,86 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
assert len(k.shape) == 3
|
assert len(k.shape) == 3
|
||||||
assert len(v.shape) == 3
|
assert len(v.shape) == 3
|
||||||
|
|
||||||
if kv_indices.shape[0] == 0:
|
if forward_batch.forward_mode.is_extend():
|
||||||
o = flash_attn_varlen_func(
|
if kv_indices.shape[0] == 0:
|
||||||
q,
|
o = flash_attn_varlen_func(
|
||||||
k,
|
q,
|
||||||
v,
|
k,
|
||||||
qo_indptr,
|
v,
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
max_q_len,
|
qo_indptr,
|
||||||
max_q_len,
|
max_q_len,
|
||||||
softmax_scale=layer.scaling,
|
max_q_len,
|
||||||
causal=True,
|
softmax_scale=layer.scaling,
|
||||||
)
|
causal=True,
|
||||||
return o
|
)
|
||||||
elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
|
return o
|
||||||
K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
|
elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
|
||||||
kvc, k_pe = torch.split(
|
K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
|
||||||
K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
|
kvc, k_pe = torch.split(
|
||||||
)
|
K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
|
||||||
kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
|
)
|
||||||
|
kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
|
||||||
|
|
||||||
kvprefix = kvprefix.view(
|
kvprefix = kvprefix.view(
|
||||||
-1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
|
-1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
|
||||||
)
|
)
|
||||||
k_prefix, v_prefix = torch.split(
|
k_prefix, v_prefix = torch.split(
|
||||||
kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
|
kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
|
||||||
)
|
)
|
||||||
k_prefix = torch.cat(
|
k_prefix = torch.cat(
|
||||||
[
|
[
|
||||||
k_prefix,
|
k_prefix,
|
||||||
torch.broadcast_to(
|
torch.broadcast_to(
|
||||||
k_pe,
|
k_pe,
|
||||||
(k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
|
(k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
forward_batch.extend_prefix_lens.shape
|
forward_batch.extend_prefix_lens.shape
|
||||||
== forward_batch.extend_seq_lens.shape
|
== forward_batch.extend_seq_lens.shape
|
||||||
)
|
)
|
||||||
k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
|
|
||||||
k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
|
|
||||||
assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
|
|
||||||
k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
|
|
||||||
v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
|
|
||||||
v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
|
|
||||||
v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
|
|
||||||
|
|
||||||
o = flash_attn_varlen_func(
|
k = k_prefix
|
||||||
q,
|
v = v_prefix
|
||||||
k,
|
|
||||||
v,
|
o = flash_attn_varlen_func(
|
||||||
qo_indptr,
|
q,
|
||||||
kv_indptr,
|
k,
|
||||||
max_q_len,
|
v,
|
||||||
max_kv_len,
|
qo_indptr,
|
||||||
softmax_scale=layer.scaling,
|
kv_indptr,
|
||||||
causal=True,
|
max_q_len,
|
||||||
)
|
max_kv_len,
|
||||||
return o
|
softmax_scale=layer.scaling,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
return o
|
||||||
|
|
||||||
|
else:
|
||||||
|
if layer.qk_head_dim != layer.v_head_dim:
|
||||||
|
o = q.new_empty(
|
||||||
|
(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
|
mla_prefill_fwd(
|
||||||
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
|
K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
|
||||||
|
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||||
|
qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
self.forward_metadata.kv_last_page_len,
|
||||||
|
self.forward_metadata.max_q_len,
|
||||||
|
layer.scaling,
|
||||||
|
layer.logit_cap,
|
||||||
|
)
|
||||||
|
K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
|
||||||
|
return o
|
||||||
elif forward_batch.forward_mode.is_target_verify():
|
elif forward_batch.forward_mode.is_target_verify():
|
||||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
|
o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
|
||||||
mla_decode_fwd(
|
mla_decode_fwd(
|
||||||
|
|||||||
@@ -1085,7 +1085,13 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
and not forward_batch.forward_mode.is_target_verify()
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
):
|
):
|
||||||
return AttnForwardMethod.MHA
|
if is_dp_attention_enabled():
|
||||||
|
if sum(forward_batch.extend_prefix_lens_cpu) == 0:
|
||||||
|
return AttnForwardMethod.MHA
|
||||||
|
else:
|
||||||
|
return AttnForwardMethod.MLA
|
||||||
|
else:
|
||||||
|
return AttnForwardMethod.MHA
|
||||||
else:
|
else:
|
||||||
return AttnForwardMethod.MLA
|
return AttnForwardMethod.MLA
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user