[Perf] Supports compute-communication overlap in the forward pass of sfa_v1 in the Sharded-CP feature. (#5701)
### What this PR does / why we need it?
> Extracted from PR #5513
Based on the Sharded-CP feature PR #4702.
RFC: https://github.com/vllm-project/vllm/issues/30055
### All-gather KV Cache for Communication Overlap
- This PR adjusts the computation order in SFA.
- Split `index_select` into `indexer_select_pre_process` and `indexer_select_post_process`.
- Combine `nope`, `rope`, and `index-k` into a single tensor so they can be all-gathered asynchronously and overlapped with computation (see the sketch after this list).
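
Below is a minimal, hypothetical sketch of the resulting overlap pattern (the tensor names, shapes, and the `overlap_fn` callback are illustrative assumptions, not the PR's exact code): the three tensors are packed into one buffer, a single all-gather is launched asynchronously, and computation that does not depend on the gathered result runs while the collective is in flight.

```python
import torch
import torch.distributed as dist


def packed_all_gather_with_overlap(k_nope: torch.Tensor,
                                   k_rope: torch.Tensor,
                                   index_k: torch.Tensor,
                                   overlap_fn,
                                   group=None):
    """Sketch: pack nope/rope/index-k, all-gather asynchronously,
    and run `overlap_fn()` while the communication is in flight."""
    world_size = dist.get_world_size(group=group)
    # One packed buffer lets a single collective cover all three tensors.
    packed = torch.cat([k_nope, k_rope, index_k], dim=-1)
    if world_size == 1:
        return packed, overlap_fn()
    output = torch.empty((packed.shape[0] * world_size, *packed.shape[1:]),
                         dtype=packed.dtype, device=packed.device)
    # async_op=True returns immediately with a work handle.
    work = dist.all_gather_into_tensor(output, packed, group=group,
                                       async_op=True)
    # Computation that does not read `output` overlaps with the gather.
    overlapped = overlap_fn()
    work.wait()
    return output, overlapped
```

In the PR, `indexer_select_pre_process` roughly corresponds to the work that produces the packed tensors before the gather, and `indexer_select_post_process` to the work performed after it.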
### Benchmark
input=40k, num_batch_token=20k. Mean TTFT drops from 2614.52 ms to 2529.92 ms (about 3.2% lower).
- before:
```
Mean TTFT (ms): 2614.52
Median TTFT (ms): 3148.03
P50 TTFT (ms): 3148.03
P90 TTFT (ms): 3163.48
P99 TTFT (ms): 3170.20
```
- after:
```
Mean TTFT (ms): 2529.92
Median TTFT (ms): 3051.69
P50 TTFT (ms): 3051.69
P90 TTFT (ms): 3067.31
P99 TTFT (ms): 3072.15
```
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
```
@@ -1,8 +1,9 @@
import os
from typing import Optional

import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group
from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group
from vllm.forward_context import get_forward_context

from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group,
@@ -90,3 +91,21 @@ def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
        offset += num_tokens_dp
    x = result
    return x


def all_gather_async(input: torch.Tensor,
                     group: GroupCoordinator,
                     output: Optional[torch.Tensor] = None,
                     async_op: bool = True):
    if group.world_size == 1:
        return input, None
    if output is None:
        input_size = input.size()
        output_size = (input_size[0] * group.world_size, ) + input_size[1:]
        output = torch.empty(output_size,
                             dtype=input.dtype,
                             device=input.device)
    return output, dist.all_gather_into_tensor(output,
                                               input,
                                               group=group.device_group,
                                               async_op=async_op)
```
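
For reference, a hypothetical usage of the new `all_gather_async` helper, assuming `cp_group` is an already-initialized `GroupCoordinator` and the tensor sizes are illustrative:

```python
# Hypothetical usage sketch of all_gather_async (sizes are illustrative).
num_local_tokens, packed_dim = 1024, 576
packed_kv = torch.randn(num_local_tokens, packed_dim)

# Launch the collective without blocking the host.
gathered, work = all_gather_async(packed_kv, cp_group, async_op=True)

# Computation that does not read `gathered` can run here and
# overlap with the communication.
other = packed_kv @ packed_kv.t()

if work is not None:  # work is None when cp_group.world_size == 1
    work.wait()
# `gathered` now holds the concatenation of every rank's packed_kv.
```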