[Perf] Supports compute-communication overlap in the forward of sfa_v1 in the Sharded-CP feature. (#5701)

### What this PR does / why we need it?
> Extracted from PR #5513
Based on the Sharded-CP feature PR:#4702;
RFC:https://github.com/vllm-project/vllm/issues/30055

### All-gather KV Cache for Communication Overlap:
- This PR adjusts the calculation order in the SFA.
- split `index_select` into `indexer_select_pre_process` and
`indexer_select_post_process`.
- Combine `nope`, `rope` and `index-k` into a tensor to perform
asynchronous all-gather.

### benchmark:
input=40k && num_batch_token=20k
- before:
```
Mean TTFT (ms):                          2614.52
Median TTFT (ms):                        3148.03
P50 TTFT (ms):                           3148.03
P90 TTFT (ms):                           3163.48
P99 TTFT (ms):                           3170.20
```

- after:
```
Mean TTFT (ms):                          2529.92
Median TTFT (ms):                        3051.69
P50 TTFT (ms):                           3051.69
P90 TTFT (ms):                           3067.31
P99 TTFT (ms):                           3072.15
```

### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
This commit is contained in:
zzhxxx
2026-01-11 09:47:27 +08:00
committed by GitHub
parent c5744e2350
commit db12c1e2c8
3 changed files with 124 additions and 61 deletions

View File

@@ -1,8 +1,9 @@
import os
from typing import Optional
import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group
from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group
from vllm.forward_context import get_forward_context
from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group,
@@ -90,3 +91,21 @@ def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
offset += num_tokens_dp
x = result
return x
def all_gather_async(input: torch.Tensor,
group: GroupCoordinator,
output: Optional[torch.Tensor] = None,
async_op: bool = True):
if group.world_size == 1:
return input, None
if output is None:
input_size = input.size()
output_size = (input_size[0] * group.world_size, ) + input_size[1:]
output = torch.empty(output_size,
dtype=input.dtype,
device=input.device)
return output, dist.all_gather_into_tensor(output,
input,
group=group.device_group,
async_op=async_op)