[Perf] Supports compute-communication overlap in the forward of sfa_v1 in the Sharded-CP feature. (#5701)
### What this PR does / why we need it?
> Extracted from PR #5513
Based on the Sharded-CP feature PR #4702;
RFC: https://github.com/vllm-project/vllm/issues/30055
### All-gather KV Cache for Communication Overlap
- This PR adjusts the calculation order in the SFA forward pass.
- Splits `indexer_select` into `indexer_select_pre_process` and
`indexer_select_post_process`.
- Concatenates `k_pe` (rope), `k_nope` and the indexer key into one tensor and
all-gathers it asynchronously, so the communication overlaps with the
query-side computation (see the sketch below).
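A minimal sketch of the intended ordering, using plain `torch.distributed` (the function name and tensor shapes are illustrative; the actual code goes through the `all_gather_async` helper added in `vllm_ascend/distributed/utils.py` and Ascend-specific kernels):
```python
import torch
import torch.distributed as dist


def gather_keys_with_overlap(k_pe: torch.Tensor, k_nope: torch.Tensor,
                             index_k: torch.Tensor,
                             group: dist.ProcessGroup):
    """Launch the KV all-gather early and wait only when the result is used."""
    world_size = dist.get_world_size(group)

    # Fuse the three per-rank key tensors ([num_tokens, dim] each) so a
    # single collective covers all of them.
    fused = torch.cat([k_pe, k_nope, index_k], dim=-1).contiguous()

    # Issue the all-gather asynchronously and keep the work handle.
    gathered = torch.empty((fused.shape[0] * world_size, fused.shape[1]),
                           dtype=fused.dtype,
                           device=fused.device)
    handle = dist.all_gather_into_tensor(gathered, fused, group=group,
                                         async_op=True)

    # ... communication-independent work runs here (q projection, q RoPE) ...

    # Synchronize right before the gathered keys are consumed, then split back.
    handle.wait()
    return gathered.split(
        [k_pe.shape[-1], k_nope.shape[-1], index_k.shape[-1]], dim=-1)
```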
### Benchmark
input = 40k, num_batch_token = 20k
- before:
```
Mean TTFT (ms): 2614.52
Median TTFT (ms): 3148.03
P50 TTFT (ms): 3148.03
P90 TTFT (ms): 3163.48
P99 TTFT (ms): 3170.20
```
- after:
```
Mean TTFT (ms): 2529.92
Median TTFT (ms): 3051.69
P50 TTFT (ms): 3051.69
P90 TTFT (ms): 3067.31
P99 TTFT (ms): 3072.15
```
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
```
@@ -24,6 +24,7 @@ from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                         trans_rope_weight, transdata,
                                         wait_for_kv_layer_from_connector)
from vllm_ascend.distributed.utils import all_gather_async
from vllm_ascend.ops.layer_shard_linear import (
    is_hidden_layer, post_process_after_loading_for_shard_weight_series,
    reach_layer_for_shard_weight_series,
@@ -227,7 +228,10 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):

        cos = cos[local_start:local_end_with_pad]
        sin = sin[local_start:local_end_with_pad]
        slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
        slot_mapping_cp = torch.full(size=(num_tokens_per_device, ),
                                     fill_value=-1,
                                     dtype=slot_mapping.dtype,
                                     device=slot_mapping.device)
        assert cos.shape[0] == num_tokens_per_device, \
            f"cos.shape[0] must be equal to num_tokens_per_device, \
                got {cos.shape[0]} and {num_tokens_per_device}"
```
```
@@ -505,7 +509,6 @@ class AscendSFAImpl(MLAAttentionImpl):
        sin: torch.Tensor,
        kv_cache: Tuple,
        slots: torch.Tensor,
        slots_cp: Optional[torch.Tensor],
    ):
        B = kv_no_split.shape[0]
        N = self.num_kv_heads
@@ -516,30 +519,19 @@ class AscendSFAImpl(MLAAttentionImpl):
        cache_mode = "PA"

        if self.enable_sfa_cp:
            assert slots_cp is not None
            _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
                kv_no_split,
                self.kv_a_layernorm.weight,
                cos,
                sin,
                slots_cp.to(torch.int64),
                slots.to(torch.int64),
                kv_cache[1],
                kv_cache[0],
                epsilon=self.kv_a_layernorm.variance_epsilon,
                cache_mode=cache_mode,
                is_output_kv=True,
            )
            # TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
            k_pe = get_tp_group().all_gather(k_pe, 0)
            k_nope = get_tp_group().all_gather(k_nope, 0)

            if kv_cache is not None:
                torch_npu.npu_scatter_nd_update_(
                    kv_cache[0].view(-1, k_nope.shape[-1]), slots.view(-1, 1),
                    k_nope.view(-1, k_nope.shape[-1]))
                torch_npu.npu_scatter_nd_update_(
                    kv_cache[1].view(-1, k_pe.shape[-1]), slots.view(-1, 1),
                    k_pe.view(-1, k_pe.shape[-1]))
            return k_pe, k_nope
        else:
            torch_npu.npu_kv_rmsnorm_rope_cache(
                kv_no_split,
@@ -552,6 +544,7 @@ class AscendSFAImpl(MLAAttentionImpl):
                epsilon=self.kv_a_layernorm.variance_epsilon,
                cache_mode=cache_mode,
            )
            return None, None

    def rope_single(
        self,
```
```
@@ -744,6 +737,7 @@ class AscendSFAImpl(MLAAttentionImpl):
                if is_hidden_layer(layer):
                    reach_layer_for_shard_weight_series(layer)
            return output.fill_(0)

        has_prefill = attn_metadata.has_prefill
        cos = attn_metadata.cos
        sin = attn_metadata.sin
@@ -764,6 +758,12 @@ class AscendSFAImpl(MLAAttentionImpl):
                need_gather_q_kv=need_gather_q_kv,
                num_input_tokens=attn_metadata.num_input_tokens,
            )
            q, k = self.indexer_select_pre_process(
                x=hidden_states,
                qr=q_c,
                cos=cos,
                sin=sin,
                need_gather_q_kv=need_gather_q_kv)
        else:
            assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
            maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
@@ -782,31 +782,67 @@ class AscendSFAImpl(MLAAttentionImpl):
            kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                kv_no_split.contiguous(), need_gather_q_kv)

            q, k = self.indexer_select_pre_process(
                x=hidden_states,
                qr=q_c,
                cos=cos,
                sin=sin,
                need_gather_q_kv=need_gather_q_kv)

        if has_prefill:
            wait_for_kv_layer_from_connector(layer_name)

        slot_mapping = attn_metadata.slot_mapping
        slot_mapping_cp = None
        if self.enable_sfa_cp:
            assert attn_metadata.sfa_cp_context is not None
            slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
            slot_mapping = attn_metadata.sfa_cp_context.slot_mapping_cp
            actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
            actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key

        self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
                     slot_mapping_cp)
        k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
                                    slot_mapping)

        if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
        for layer in (self.layer_sharding_kwargs or []):
            if is_hidden_layer(layer):
                reach_layer_for_shard_weight_series(layer)
        if self.enable_sfa_cp:
            assert k_pe is not None
            assert k_nope is not None
            # support all_gather kv async for communication calculation overlap
            fused_kv_no_split, kv_ag_handle = all_gather_async(
                torch.cat([
                    k_pe.view(-1, k_pe.shape[-1]),
                    k_nope.view(-1, k_nope.shape[-1]),
                    k.view(-1, k.shape[-1])
                ],
                          dim=1), get_tp_group())

        ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
        q_pe = self.rope_single(q_pe, cos, sin)

        topk_indices = self.indexer_select(
        if self.enable_sfa_cp:
            if kv_ag_handle is not None:
                kv_ag_handle.wait()
            for layer in (self.layer_sharding_kwargs or []):
                if is_hidden_layer(layer):
                    reach_layer_for_shard_weight_series(layer)

            if kv_cache is not None:
                assert fused_kv_no_split is not None
                k_pe, k_nope, k = fused_kv_no_split.split([
                    self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim
                ],
                                                          dim=-1)
                slot_mapping = attn_metadata.slot_mapping.view(-1, 1)
                torch_npu.npu_scatter_nd_update_(
                    kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping,
                    k_nope)
                torch_npu.npu_scatter_nd_update_(
                    kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping,
                    k_pe)

        topk_indices = self.indexer_select_post_process(
            x=hidden_states,
            qr=q_c,
            q=q,
            k=k,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
            cos=cos,
```
```
@@ -814,6 +850,7 @@ class AscendSFAImpl(MLAAttentionImpl):
            actual_seq_lengths_query=actual_seq_lengths_query,
            actual_seq_lengths_key=actual_seq_lengths_key,
            need_gather_q_kv=need_gather_q_kv)

        attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
            query=ql_nope,
            key=kv_cache[0],
@@ -830,6 +867,7 @@ class AscendSFAImpl(MLAAttentionImpl):
            layout_kv="PA_BSND",
            sparse_mode=3,
        )

        attn_output = self._v_up_proj(attn_output, has_prefill)
        maybe_npu_prefetch(inputs=self.o_proj.weight,
                           dependency=attn_output,
@@ -838,22 +876,14 @@ class AscendSFAImpl(MLAAttentionImpl):
        output[...] = self.o_proj(attn_output)[0]
        return output_padded

    def indexer_select(
    def indexer_select_pre_process(
        self,
        x: torch.Tensor,
        qr: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        attn_metadata: M,
        cos: torch.Tensor,
        sin: torch.Tensor,
        actual_seq_lengths_query: torch.Tensor,
        actual_seq_lengths_key: torch.Tensor,
        need_gather_q_kv: bool = False,
    ):
        # q process in new stream
        q, _ = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
        q = q.view(-1, self.n_head, self.head_dim)  # [n_toks,64,128]

        k_proj, _ = self.wk(x)  # [b,s,7168] @ [7168,128] = [b,s,128]
        k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            k_proj, need_gather_q_kv)
```
```
@@ -861,6 +891,9 @@ class AscendSFAImpl(MLAAttentionImpl):
        k = k.view(-1, 1, self.head_dim)

        if HAS_TRITON:
            q, _ = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
            q = q.view(-1, self.n_head, self.head_dim)  # [n_toks,64,128]

            cos = cos.view(-1, self.qk_rope_head_dim)
            sin = sin.view(-1, self.qk_rope_head_dim)
            q, k = rope_forward_triton(q,
@@ -870,6 +903,38 @@ class AscendSFAImpl(MLAAttentionImpl):
                                       rope_dim=self.qk_rope_head_dim,
                                       is_neox_style=True)
        else:
            k_pe, k_nope = torch.split(
                k,
                [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
                dim=-1)  # [b,s,64+64]

            k_pe = k_pe.unsqueeze(2)
            k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
            k_pe = k_pe.squeeze(2)

            k = torch.cat([k_pe, k_nope], dim=-1)  # [b*s,128]
            q = None

        return q, k

    def indexer_select_post_process(
        self,
        x: torch.Tensor,
        qr: torch.Tensor,
        q: Optional[torch.Tensor],
        k: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        attn_metadata: M,
        cos: torch.Tensor,
        sin: torch.Tensor,
        actual_seq_lengths_query: torch.Tensor,
        actual_seq_lengths_key: torch.Tensor,
        need_gather_q_kv: bool = False,
    ):
        if q is None:
            q, _ = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
            q = q.view(-1, self.n_head, self.head_dim)  # [n_toks,64,128]

        cos_q, sin_q = cos, sin
        cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
        sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
@@ -884,20 +949,6 @@ class AscendSFAImpl(MLAAttentionImpl):
        q_pe = q_pe.squeeze(2)
        q = torch.cat([q_pe, q_nope], dim=-1)  # [b*s,64,128]

        k_pe, k_nope = torch.split(
            k,
            [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
            dim=-1)  # [b,s,64+64]

        k_pe = k_pe.unsqueeze(2)
        k_pe = torch_npu.npu_rotary_mul(k_pe, cos, sin)
        k_pe = k_pe.squeeze(2)

        k = torch.cat([k_pe, k_nope], dim=-1)  # [b*s,128]

        if self.enable_sfa_cp:
            k = get_tp_group().all_gather(k, 0)

        if kv_cache is not None:
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event = torch.npu.Event()
```
```
@@ -1,8 +1,9 @@
import os
from typing import Optional

import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group
from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group
from vllm.forward_context import get_forward_context

from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group,
@@ -90,3 +91,21 @@ def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
        offset += num_tokens_dp
    x = result
    return x


def all_gather_async(input: torch.Tensor,
                     group: GroupCoordinator,
                     output: Optional[torch.Tensor] = None,
                     async_op: bool = True):
    if group.world_size == 1:
        return input, None
    if output is None:
        input_size = input.size()
        output_size = (input_size[0] * group.world_size, ) + input_size[1:]
        output = torch.empty(output_size,
                             dtype=input.dtype,
                             device=input.device)
    return output, dist.all_gather_into_tensor(output,
                                               input,
                                               group=group.device_group,
                                               async_op=async_op)
```
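The helper returns `(output, handle)` and short-circuits to `(input, None)` for a single-rank group, so callers can run unrelated work between issuing the gather and `wait()`. A minimal usage sketch (the `get_tp_group` import path is an assumption; in the sfa_v1 hunks above it is already in scope):
```python
import torch
from vllm.distributed.parallel_state import get_tp_group  # assumed import path

from vllm_ascend.distributed.utils import all_gather_async


def gather_with_overlap(fused_k: torch.Tensor) -> torch.Tensor:
    # Issue the TP all-gather without blocking the compute stream.
    gathered, handle = all_gather_async(fused_k, get_tp_group())

    # ... communication-independent work runs here ...

    # world_size == 1 yields (input, None), so only wait when a handle exists.
    if handle is not None:
        handle.wait()
    return gathered
```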
```
@@ -1172,20 +1172,13 @@ def singleton(cls):
@lru_cache(maxsize=1)
def enable_dsa_cp() -> bool:
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    if vllm_config is None:
        return False

    model_config = getattr(vllm_config, "model_config", None)
    if model_config is None:
        return False

    hf_text_config = getattr(model_config, "hf_text_config", None)
    if hf_text_config is None:
        return False

    return hasattr(hf_text_config, "index_topk")
    is_ds_v32 = hasattr(
        vllm_config.model_config, "hf_text_config") and hasattr(
            vllm_config.model_config.hf_text_config, "index_topk")
    if is_ds_v32 and enable_sp():
        return True
    return False


@lru_cache(maxsize=1)
```