[Bugfix] Fix the bug in initializing the shared_weight communication domain in sfa-cp, and fix the mtp weight load in pp>1 situation (#4913)

### What this PR does / why we need it?
In PR #4188, a small bug was introduced that caused sfa-cp to be unable
to find the global_pp_size parameter during initialization, and this PR
fixed the issue.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
zzhxxx
2025-12-15 16:21:49 +08:00
committed by GitHub
parent 70606e0bb9
commit e16444f21f
2 changed files with 10 additions and 6 deletions

View File

@@ -182,9 +182,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
global _SHARED_WEIGHT
# TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
if enable_sp() and is_ds_v32:
if enable_sp() and is_ds_v32 and _SHARED_WEIGHT is None:
_SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
# TODO: Extract and unify the logic across different communication group.
if flashcomm2_enable():
flashcomm2_otp_size = get_ascend_config(
@@ -240,7 +239,9 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
# Create shared weight group for flashcomm2 oproj
if flashcomm2_o_shared_enabled():
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
_SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
if _SHARED_WEIGHT is None:
_SHARED_WEIGHT = _create_shared_weight_group(
"flashcomm2_o_shared")
if get_ascend_config().multistream_overlap_gate:
global _FC3_QUANT_X