[Bugfix] Fix the bug in initializing the shared_weight communication domain in sfa-cp, and fix the mtp weight load in pp>1 situation (#4913)

### What this PR does / why we need it?
In PR #4188, a small bug was introduced that caused sfa-cp to be unable
to find the global_pp_size parameter during initialization, and this PR
fixed the issue.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
zzhxxx
2025-12-15 16:21:49 +08:00
committed by GitHub
parent 70606e0bb9
commit e16444f21f
2 changed files with 10 additions and 6 deletions

View File

@@ -8,6 +8,7 @@ import torch.nn.functional as F
from vllm.config import (CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config)
from vllm.distributed import get_pcp_group
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -205,9 +206,11 @@ class MtpProposer(Proposer):
if self.vllm_config.model_config.is_deepseek_mla:
# check if mtp model use main model's embedding and LMhead
main_model = model
if torch.equal(self.model.model.embed_tokens.weight,
main_model.model.embed_tokens.weight):
self.model.model.embed_tokens = main_model.model.embed_tokens
if get_pp_group().world_size == 1:
# If pp>1, the weights of mtp and the main model's embedding are not on the same device.
if torch.equal(self.model.model.embed_tokens.weight,
main_model.model.embed_tokens.weight):
self.model.model.embed_tokens = main_model.model.embed_tokens
for _, layer_module in self.model.model.layers.items():
if torch.equal(layer_module.shared_head.head.weight,
main_model.lm_head.weight):