[Perf]enable prefill flashcommon3 (#4065)

### What this PR does / why we need it?
moe multistream overlap to improve the performance.

### How was this patch tested?
--additional-config '{"multistream_overlap_gate": true}'

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
AlvisGong
2025-12-14 09:34:13 +08:00
committed by GitHub
parent 0686b32d82
commit ba28d54f35
8 changed files with 239 additions and 40 deletions

View File

@@ -20,9 +20,10 @@ _OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
_EMBED_TP: Optional[GroupCoordinator] = None
# flashcomm2 specific groups
# flashcomm specific groups
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
_FC3_QUANT_X: Optional[GroupCoordinator] = None
# shared_weight across rank groups
_SHARED_WEIGHT: Optional[GroupCoordinator] = None
@@ -241,6 +242,15 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
_SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
if get_ascend_config().multistream_overlap_gate:
global _FC3_QUANT_X
group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_FC3_QUANT_X = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="fc3_quant_x")
def model_parallel_initialized():
return (_MC2 is not None)
@@ -296,6 +306,11 @@ def get_p_tp_group() -> GroupCoordinator:
return _P_TP
def get_fc3_quant_x_group() -> GroupCoordinator:
assert _FC3_QUANT_X is not None, ("fc3 quant x group is not initialized")
return _FC3_QUANT_X
def destroy_ascend_model_parallel():
global _MC2
if _MC2:
@@ -343,3 +358,8 @@ def destroy_ascend_model_parallel():
if _SHARED_WEIGHT:
_SHARED_WEIGHT.destroy()
_SHARED_WEIGHT = None
global _FC3_QUANT_X
if _FC3_QUANT_X:
_FC3_QUANT_X.destroy()
_FC3_QUANT_X = None