[Feat] Enable SFA CP for DeepSeek V3.2 (#4702)
### What this PR does / why we need it?
Enables SFA context parallelism (SFA CP) for DeepSeek V3.2. When sequence parallelism (FlashComm1) is enabled and the model is detected as DeepSeek V3.2, a new `_SHARED_WEIGHT` group coordinator spanning all ranks is initialized (and torn down on shutdown) so shared weights can be handled across the CP group.

RFC: https://github.com/vllm-project/vllm/issues/30055
### How was this patch tested?
1. Enable FlashComm1:
   `export VLLM_ASCEND_ENABLE_FLASHCOMM1=1`
2. Enable SFA CP by passing the additional config to the server (a combined launch sketch is shown below):
   `--additional-config '{ "enable_sfa_cp": true }'`
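As a rough sketch, the two steps above can be combined into a single launch. The model path is a placeholder and is not taken from this PR; deployment-specific flags (parallel sizes, ports, etc.) are omitted.

```bash
# Sketch only: enable FlashComm1 via the environment variable, then pass
# the SFA CP switch through --additional-config when starting the server.
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
vllm serve /path/to/DeepSeek-V3.2 \
    --additional-config '{ "enable_sfa_cp": true }'
```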
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: AlvisGong <gwly0401@163.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: hwhaokun <haokun0405@163.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
```diff
@@ -9,7 +9,7 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.utils import flashcomm2_enable
+from vllm_ascend.utils import enable_sp, flashcomm2_enable
 
 # Currently, mc2 op need their own group coordinator.
 _MC2: Optional[GroupCoordinator] = None
@@ -19,6 +19,7 @@ _LMTP: Optional[GroupCoordinator] = None
 _P_TP: Optional[GroupCoordinator] = None
 _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
 _FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
+_SHARED_WEIGHT: Optional[GroupCoordinator] = None
 
 
 def get_mc2_group() -> GroupCoordinator:
@@ -48,6 +49,13 @@ def get_flashcomm2_odp_group() -> GroupCoordinator:
     return _FLASHCOMM2_ODP
 
 
+def get_shared_weight_group() -> GroupCoordinator:
+    assert _SHARED_WEIGHT is not None, (
+        "output shared weight parallel group for flashcomm2 is not initialized"
+    )
+    return _SHARED_WEIGHT
+
+
 def get_mlp_tp_group() -> GroupCoordinator:
     assert _MLP_TP is not None, ("mlp group is not initialized")
     return _MLP_TP
@@ -226,6 +234,18 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
             backend,
             group_name="flashcomm2_odp")
 
+    vllm_config = get_current_vllm_config()
+    # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
+    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
+    if enable_sp() and is_ds_v32:
+        global _SHARED_WEIGHT
+        group_ranks = [list(range(torch.distributed.get_world_size()))]
+        _SHARED_WEIGHT = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="CP_shared_weight")
+
 
 def get_mlp_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
@@ -274,3 +294,8 @@ def destroy_ascend_model_parallel():
     ).flashcomm2_oproj_tensor_parallel_size != 1:
         _FLASHCOMM2_ODP.destroy()
         _FLASHCOMM2_ODP = None
+
+    global _SHARED_WEIGHT
+    if _SHARED_WEIGHT:
+        _SHARED_WEIGHT.destroy()
+        _SHARED_WEIGHT = None
```