[Perf]enable prefill flashcommon3 (#4065)
### What this PR does / why we need it?
Enable MoE multistream overlap to improve performance.
### How was this patch tested?
Tested with `--additional-config '{"multistream_overlap_gate": true}'`.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
@@ -20,9 +20,10 @@ _OTP: Optional[GroupCoordinator] = None
 _LMTP: Optional[GroupCoordinator] = None
 _EMBED_TP: Optional[GroupCoordinator] = None
 
-# flashcomm2 specific groups
+# flashcomm specific groups
 _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
 _FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
+_FC3_QUANT_X: Optional[GroupCoordinator] = None
 
 # shared_weight across rank groups
 _SHARED_WEIGHT: Optional[GroupCoordinator] = None
@@ -241,6 +242,15 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
         assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
         _SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
 
+    if get_ascend_config().multistream_overlap_gate:
+        global _FC3_QUANT_X
+        group_ranks = all_ranks.unbind(0)
+        group_ranks = [x.tolist() for x in group_ranks]
+        _FC3_QUANT_X = init_model_parallel_group(group_ranks,
+                                                 get_world_group().local_rank,
+                                                 backend,
+                                                 group_name="fc3_quant_x")
+
 
 def model_parallel_initialized():
     return (_MC2 is not None)
@@ -296,6 +306,11 @@ def get_p_tp_group() -> GroupCoordinator:
     return _P_TP
 
 
+def get_fc3_quant_x_group() -> GroupCoordinator:
+    assert _FC3_QUANT_X is not None, ("fc3 quant x group is not initialized")
+    return _FC3_QUANT_X
+
+
 def destroy_ascend_model_parallel():
     global _MC2
     if _MC2:
@@ -343,3 +358,8 @@ def destroy_ascend_model_parallel():
     if _SHARED_WEIGHT:
         _SHARED_WEIGHT.destroy()
         _SHARED_WEIGHT = None
+
+    global _FC3_QUANT_X
+    if _FC3_QUANT_X:
+        _FC3_QUANT_X.destroy()
+        _FC3_QUANT_X = None
Reference in New Issue
Block a user