[Bugfix] Fix the bug in initializing the shared_weight communication domain in sfa-cp, and fix the mtp weight load in pp>1 situation (#4913)
### What this PR does / why we need it?
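PR #4188 introduced a small bug that left sfa-cp unable to find the
global_pp_size parameter while initializing the shared_weight
communication domain; this PR fixes that.
This PR also fixes MTP weight loading when pp > 1: in that case the MTP and
main-model embedding weights are not on the same device, so they are only
compared (and shared) when the pipeline-parallel world size is 1.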
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -182,9 +182,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
|||||||
global _SHARED_WEIGHT
|
global _SHARED_WEIGHT
|
||||||
# TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
|
# TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
|
||||||
is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
|
is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
|
||||||
if enable_sp() and is_ds_v32:
|
if enable_sp() and is_ds_v32 and _SHARED_WEIGHT is None:
|
||||||
_SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
|
_SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
|
||||||
|
|
||||||
# TODO: Extract and unify the logic across different communication group.
|
# TODO: Extract and unify the logic across different communication group.
|
||||||
if flashcomm2_enable():
|
if flashcomm2_enable():
|
||||||
flashcomm2_otp_size = get_ascend_config(
|
flashcomm2_otp_size = get_ascend_config(
|
||||||
@@ -240,7 +239,9 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
|||||||
# Create shared weight group for flashcomm2 oproj
|
# Create shared weight group for flashcomm2 oproj
|
||||||
if flashcomm2_o_shared_enabled():
|
if flashcomm2_o_shared_enabled():
|
||||||
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
|
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
|
||||||
_SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
|
if _SHARED_WEIGHT is None:
|
||||||
|
_SHARED_WEIGHT = _create_shared_weight_group(
|
||||||
|
"flashcomm2_o_shared")
|
||||||
|
|
||||||
if get_ascend_config().multistream_overlap_gate:
|
if get_ascend_config().multistream_overlap_gate:
|
||||||
global _FC3_QUANT_X
|
global _FC3_QUANT_X
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|||||||
from vllm.config import (CUDAGraphMode, VllmConfig,
|
from vllm.config import (CUDAGraphMode, VllmConfig,
|
||||||
get_layers_from_vllm_config, set_current_vllm_config)
|
get_layers_from_vllm_config, set_current_vllm_config)
|
||||||
from vllm.distributed import get_pcp_group
|
from vllm.distributed import get_pcp_group
|
||||||
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
@@ -205,9 +206,11 @@ class MtpProposer(Proposer):
|
|||||||
if self.vllm_config.model_config.is_deepseek_mla:
|
if self.vllm_config.model_config.is_deepseek_mla:
|
||||||
# check if mtp model use main model's embedding and LMhead
|
# check if mtp model use main model's embedding and LMhead
|
||||||
main_model = model
|
main_model = model
|
||||||
if torch.equal(self.model.model.embed_tokens.weight,
|
if get_pp_group().world_size == 1:
|
||||||
main_model.model.embed_tokens.weight):
|
# If pp>1, the weights of mtp and the main model's embedding are not on the same device.
|
||||||
self.model.model.embed_tokens = main_model.model.embed_tokens
|
if torch.equal(self.model.model.embed_tokens.weight,
|
||||||
|
main_model.model.embed_tokens.weight):
|
||||||
|
self.model.model.embed_tokens = main_model.model.embed_tokens
|
||||||
for _, layer_module in self.model.model.layers.items():
|
for _, layer_module in self.model.model.layers.items():
|
||||||
if torch.equal(layer_module.shared_head.head.weight,
|
if torch.equal(layer_module.shared_head.head.weight,
|
||||||
main_model.lm_head.weight):
|
main_model.lm_head.weight):
|
||||||
|
|||||||
Reference in New Issue
Block a user