[Perf] Supports compute-communication overlap in the forward of sfa_v1 in the Sharded-CP feature. (#5701)

### What this PR does / why we need it?
> Extracted from PR #5513
Based on the Sharded-CP feature PR:#4702;
RFC:https://github.com/vllm-project/vllm/issues/30055

### All-gather KV Cache for Communication Overlap:
- This PR adjusts the calculation order in the SFA.
- split `index_select` into `indexer_select_pre_process` and
`indexer_select_post_process`.
- Combine `nope`, `rope` and `index-k` into a tensor to perform
asynchronous all-gather.

### benchmark:
input=40k && num_batch_token=20k
- before:
```
Mean TTFT (ms):                          2614.52
Median TTFT (ms):                        3148.03
P50 TTFT (ms):                           3148.03
P90 TTFT (ms):                           3163.48
P99 TTFT (ms):                           3170.20
```

- after:
```
Mean TTFT (ms):                          2529.92
Median TTFT (ms):                        3051.69
P50 TTFT (ms):                           3051.69
P90 TTFT (ms):                           3067.31
P99 TTFT (ms):                           3072.15
```

### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
This commit is contained in:
zzhxxx
2026-01-11 09:47:27 +08:00
committed by GitHub
parent c5744e2350
commit db12c1e2c8
3 changed files with 124 additions and 61 deletions

View File

@@ -24,6 +24,7 @@ from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.distributed.utils import all_gather_async
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
@@ -227,7 +228,10 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
cos = cos[local_start:local_end_with_pad]
sin = sin[local_start:local_end_with_pad]
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
slot_mapping_cp = torch.full(size=(num_tokens_per_device, ),
fill_value=-1,
dtype=slot_mapping.dtype,
device=slot_mapping.device)
assert cos.shape[0] == num_tokens_per_device, \
f"cos.shape[0] must be equal to num_tokens_per_device, \
got {cos.shape[0]} and {num_tokens_per_device}"
@@ -505,7 +509,6 @@ class AscendSFAImpl(MLAAttentionImpl):
sin: torch.Tensor,
kv_cache: Tuple,
slots: torch.Tensor,
slots_cp: Optional[torch.Tensor],
):
B = kv_no_split.shape[0]
N = self.num_kv_heads
@@ -516,30 +519,19 @@ class AscendSFAImpl(MLAAttentionImpl):
cache_mode = "PA"
if self.enable_sfa_cp:
assert slots_cp is not None
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
cos,
sin,
slots_cp.to(torch.int64),
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
is_output_kv=True,
)
# TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
k_pe = get_tp_group().all_gather(k_pe, 0)
k_nope = get_tp_group().all_gather(k_nope, 0)
if kv_cache is not None:
torch_npu.npu_scatter_nd_update_(
kv_cache[0].view(-1, k_nope.shape[-1]), slots.view(-1, 1),
k_nope.view(-1, k_nope.shape[-1]))
torch_npu.npu_scatter_nd_update_(
kv_cache[1].view(-1, k_pe.shape[-1]), slots.view(-1, 1),
k_pe.view(-1, k_pe.shape[-1]))
return k_pe, k_nope
else:
torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
@@ -552,6 +544,7 @@ class AscendSFAImpl(MLAAttentionImpl):
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return None, None
def rope_single(
self,
@@ -744,6 +737,7 @@ class AscendSFAImpl(MLAAttentionImpl):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)
has_prefill = attn_metadata.has_prefill
cos = attn_metadata.cos
sin = attn_metadata.sin
@@ -764,6 +758,12 @@ class AscendSFAImpl(MLAAttentionImpl):
need_gather_q_kv=need_gather_q_kv,
num_input_tokens=attn_metadata.num_input_tokens,
)
q, k = self.indexer_select_pre_process(
x=hidden_states,
qr=q_c,
cos=cos,
sin=sin,
need_gather_q_kv=need_gather_q_kv)
else:
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
@@ -782,31 +782,67 @@ class AscendSFAImpl(MLAAttentionImpl):
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split.contiguous(), need_gather_q_kv)
q, k = self.indexer_select_pre_process(
x=hidden_states,
qr=q_c,
cos=cos,
sin=sin,
need_gather_q_kv=need_gather_q_kv)
if has_prefill:
wait_for_kv_layer_from_connector(layer_name)
slot_mapping = attn_metadata.slot_mapping
slot_mapping_cp = None
if self.enable_sfa_cp:
assert attn_metadata.sfa_cp_context is not None
slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
slot_mapping = attn_metadata.sfa_cp_context.slot_mapping_cp
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
slot_mapping_cp)
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
slot_mapping)
if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
if self.enable_sfa_cp:
assert k_pe is not None
assert k_nope is not None
# support all_gather kv async for communication calculation overlap
fused_kv_no_split, kv_ag_handle = all_gather_async(
torch.cat([
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
k.view(-1, k.shape[-1])
],
dim=1), get_tp_group())
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)
topk_indices = self.indexer_select(
if self.enable_sfa_cp:
if kv_ag_handle is not None:
kv_ag_handle.wait()
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
if kv_cache is not None:
assert fused_kv_no_split is not None
k_pe, k_nope, k = fused_kv_no_split.split([
self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim
],
dim=-1)
slot_mapping = attn_metadata.slot_mapping.view(-1, 1)
torch_npu.npu_scatter_nd_update_(
kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping,
k_nope)
torch_npu.npu_scatter_nd_update_(
kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping,
k_pe)
topk_indices = self.indexer_select_post_process(
x=hidden_states,
qr=q_c,
q=q,
k=k,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
cos=cos,
@@ -814,6 +850,7 @@ class AscendSFAImpl(MLAAttentionImpl):
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
need_gather_q_kv=need_gather_q_kv)
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
query=ql_nope,
key=kv_cache[0],
@@ -830,6 +867,7 @@ class AscendSFAImpl(MLAAttentionImpl):
layout_kv="PA_BSND",
sparse_mode=3,
)
attn_output = self._v_up_proj(attn_output, has_prefill)
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=attn_output,
@@ -838,22 +876,14 @@ class AscendSFAImpl(MLAAttentionImpl):
output[...] = self.o_proj(attn_output)[0]
return output_padded
def indexer_select(
def indexer_select_pre_process(
self,
x: torch.Tensor,
qr: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
cos: torch.Tensor,
sin: torch.Tensor,
actual_seq_lengths_query: torch.Tensor,
actual_seq_lengths_key: torch.Tensor,
need_gather_q_kv: bool = False,
):
# q process in new stream
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
k_proj, need_gather_q_kv)
@@ -861,6 +891,9 @@ class AscendSFAImpl(MLAAttentionImpl):
k = k.view(-1, 1, self.head_dim)
if HAS_TRITON:
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
cos = cos.view(-1, self.qk_rope_head_dim)
sin = sin.view(-1, self.qk_rope_head_dim)
q, k = rope_forward_triton(q,
@@ -870,6 +903,38 @@ class AscendSFAImpl(MLAAttentionImpl):
rope_dim=self.qk_rope_head_dim,
is_neox_style=True)
else:
k_pe, k_nope = torch.split(
k,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64]
k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
k_pe = k_pe.squeeze(2)
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
q = None
return q, k
def indexer_select_post_process(
self,
x: torch.Tensor,
qr: torch.Tensor,
q: Optional[torch.Tensor],
k: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
cos: torch.Tensor,
sin: torch.Tensor,
actual_seq_lengths_query: torch.Tensor,
actual_seq_lengths_key: torch.Tensor,
need_gather_q_kv: bool = False,
):
if q is None:
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
@@ -884,20 +949,6 @@ class AscendSFAImpl(MLAAttentionImpl):
q_pe = q_pe.squeeze(2)
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
k_pe, k_nope = torch.split(
k,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64]
k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_rotary_mul(k_pe, cos, sin)
k_pe = k_pe.squeeze(2)
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
if self.enable_sfa_cp:
k = get_tp_group().all_gather(k, 0)
if kv_cache is not None:
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()

View File

@@ -1,8 +1,9 @@
import os
from typing import Optional
import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group
from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group
from vllm.forward_context import get_forward_context
from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group,
@@ -90,3 +91,21 @@ def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
offset += num_tokens_dp
x = result
return x
def all_gather_async(input: torch.Tensor,
group: GroupCoordinator,
output: Optional[torch.Tensor] = None,
async_op: bool = True):
if group.world_size == 1:
return input, None
if output is None:
input_size = input.size()
output_size = (input_size[0] * group.world_size, ) + input_size[1:]
output = torch.empty(output_size,
dtype=input.dtype,
device=input.device)
return output, dist.all_gather_into_tensor(output,
input,
group=group.device_group,
async_op=async_op)

View File

@@ -1172,20 +1172,13 @@ def singleton(cls):
@lru_cache(maxsize=1)
def enable_dsa_cp() -> bool:
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
if vllm_config is None:
return False
model_config = getattr(vllm_config, "model_config", None)
if model_config is None:
return False
hf_text_config = getattr(model_config, "hf_text_config", None)
if hf_text_config is None:
return False
return hasattr(hf_text_config, "index_topk")
is_ds_v32 = hasattr(
vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk")
if is_ds_v32 and enable_sp():
return True
return False
@lru_cache(maxsize=1)