[Perf] Enable prefill flashcomm3 (#4065)
### What this PR does / why we need it?
Enables MoE multistream overlap to improve prefill performance.
### How was this patch tested?
Tested with the feature enabled via `--additional-config '{"multistream_overlap_gate": true}'`.
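For reference, a minimal offline sketch of the same setting through the Python API; the model name is a placeholder, and `additional_config` is the engine-arg counterpart of the CLI flag above:

```python
# Hedged repro sketch: enables the MoE multistream overlap gate through
# vLLM's additional_config engine arg (equivalent to --additional-config).
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",  # placeholder MoE checkpoint
    additional_config={"multistream_overlap_gate": True},
)
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)
```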
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
`vllm_ascend/distributed/parallel_state.py`:

```diff
@@ -20,9 +20,10 @@ _OTP: Optional[GroupCoordinator] = None
 _LMTP: Optional[GroupCoordinator] = None
 _EMBED_TP: Optional[GroupCoordinator] = None
 
-# flashcomm2 specific groups
+# flashcomm specific groups
 _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
 _FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
+_FC3_QUANT_X: Optional[GroupCoordinator] = None
 
 # shared_weight across rank groups
 _SHARED_WEIGHT: Optional[GroupCoordinator] = None
@@ -241,6 +242,15 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
         assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
         _SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
 
+    if get_ascend_config().multistream_overlap_gate:
+        global _FC3_QUANT_X
+        group_ranks = all_ranks.unbind(0)
+        group_ranks = [x.tolist() for x in group_ranks]
+        _FC3_QUANT_X = init_model_parallel_group(group_ranks,
+                                                 get_world_group().local_rank,
+                                                 backend,
+                                                 group_name="fc3_quant_x")
+
 
 def model_parallel_initialized():
     return (_MC2 is not None)
```
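As a side note, the `group_ranks` construction above simply unbinds the rank tensor row by row; a minimal sketch with an assumed 4-rank, 2-group layout:

```python
# Illustration only: the real all_ranks tensor is built inside
# init_ascend_model_parallel; the shape here is assumed.
import torch

all_ranks = torch.arange(4).reshape(2, 2)  # 2 groups x 2 ranks (assumed)
group_ranks = all_ranks.unbind(0)          # tuple of per-group rank rows
group_ranks = [x.tolist() for x in group_ranks]
print(group_ranks)  # [[0, 1], [2, 3]]
```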
```diff
@@ -296,6 +306,11 @@ def get_p_tp_group() -> GroupCoordinator:
     return _P_TP
 
 
+def get_fc3_quant_x_group() -> GroupCoordinator:
+    assert _FC3_QUANT_X is not None, ("fc3 quant x group is not initialized")
+    return _FC3_QUANT_X
+
+
 def destroy_ascend_model_parallel():
     global _MC2
     if _MC2:
@@ -343,3 +358,8 @@ def destroy_ascend_model_parallel():
     if _SHARED_WEIGHT:
         _SHARED_WEIGHT.destroy()
         _SHARED_WEIGHT = None
+
+    global _FC3_QUANT_X
+    if _FC3_QUANT_X:
+        _FC3_QUANT_X.destroy()
+        _FC3_QUANT_X = None
```
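Since `get_fc3_quant_x_group()` asserts when the group was never created, callers are expected to gate on the same config flag; a hedged usage sketch (the helper name is hypothetical, the `ascend_config` import path is assumed):

```python
from vllm_ascend.ascend_config import get_ascend_config  # path assumed
from vllm_ascend.distributed.parallel_state import get_fc3_quant_x_group


def maybe_all_gather_quant_x(x):
    # Hypothetical helper: skip the collective when the feature is off,
    # since _FC3_QUANT_X is only initialized under multistream_overlap_gate.
    if not get_ascend_config().multistream_overlap_gate:
        return x
    return get_fc3_quant_x_group().all_gather(x, 0)
```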
And in the distributed utils module (home of `kv_alltoall_and_rearrange`):

```diff
@@ -2,8 +2,11 @@ import os
 
 import torch
 import torch.distributed as dist
+from vllm.forward_context import get_forward_context
 
-from vllm_ascend.distributed.parallel_state import get_p_tp_group
+from vllm_ascend.distributed.parallel_state import (get_dp_group,
+                                                    get_fc3_quant_x_group,
+                                                    get_p_tp_group)
 
 
 def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
@@ -59,3 +62,31 @@ def get_transfer_timeout_value():
                                              '7'))  # type: ignore
     return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
                3000)
+
+
+def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        return x
+    x = get_fc3_quant_x_group().all_gather(x, 0)
+    dp_metadata = forward_context.dp_metadata
+    if dp_metadata is None:
+        pad_size = forward_context.pad_size
+        if pad_size > 0:
+            x = x[:-pad_size]
+    else:
+        # unpad
+        num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
+        result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
+                             device=x.device,
+                             dtype=x.dtype)
+        dp_size = get_dp_group().world_size
+        x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
+        offset = 0
+        for idx in range(dp_size):
+            num_tokens_dp = num_tokens_across_dp_cpu[idx]
+            result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp]
+            offset += num_tokens_dp
+        x = result
+    return x
```
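To make the unpad branch of `fc3_all_gather_and_maybe_unpad_impl` concrete, here is a self-contained sketch (pure PyTorch, no distributed setup; sizes are assumed) of how the gathered, per-rank-padded tensor is compacted back into a contiguous sequence:

```python
import torch

# Assumed toy sizes: 3 DP ranks, each padded to 4 tokens, hidden dim 2.
dp_size, padded_length, hidden = 3, 4, 2
num_tokens_across_dp_cpu = torch.tensor([2, 4, 3])  # real tokens per rank

# Stand-in for the all-gather output: (dp_size * padded_length, hidden).
x = torch.arange(dp_size * padded_length * hidden,
                 dtype=torch.float32).view(dp_size * padded_length, hidden)

# Same rearrangement as the diff: copy each rank's real tokens, drop padding.
result = torch.empty((int(num_tokens_across_dp_cpu.sum()), hidden),
                     dtype=x.dtype)
x = x.view(dp_size, padded_length, hidden)
offset = 0
for idx in range(dp_size):
    n = int(num_tokens_across_dp_cpu[idx])
    result[offset:offset + n] = x[idx, :n]
    offset += n

print(result.shape)  # torch.Size([9, 2]) -> 2 + 4 + 3 unpadded tokens
```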