diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index ce932fbb..1709d150 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -182,9 +182,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     global _SHARED_WEIGHT
     # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
     is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
-    if enable_sp() and is_ds_v32:
+    if enable_sp() and is_ds_v32 and _SHARED_WEIGHT is None:
         _SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
-
     # TODO: Extract and unify the logic across different communication group.
     if flashcomm2_enable():
         flashcomm2_otp_size = get_ascend_config(
@@ -240,7 +239,9 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     # Create shared weight group for flashcomm2 oproj
     if flashcomm2_o_shared_enabled():
         assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
-        _SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
+        if _SHARED_WEIGHT is None:
+            _SHARED_WEIGHT = _create_shared_weight_group(
+                "flashcomm2_o_shared")
 
     if get_ascend_config().multistream_overlap_gate:
         global _FC3_QUANT_X
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 28882924..e6f856b5 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from vllm.config import (CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.distributed import get_pcp_group
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -205,9 +206,11 @@ class MtpProposer(Proposer):
         if self.vllm_config.model_config.is_deepseek_mla:
             # check if mtp model use main model's embedding and LMhead
             main_model = model
-            if torch.equal(self.model.model.embed_tokens.weight,
-                           main_model.model.embed_tokens.weight):
-                self.model.model.embed_tokens = main_model.model.embed_tokens
+            if get_pp_group().world_size == 1:
+                # If pp > 1, the mtp model's and the main model's embedding weights are not on the same device.
+                if torch.equal(self.model.model.embed_tokens.weight,
+                               main_model.model.embed_tokens.weight):
+                    self.model.model.embed_tokens = main_model.model.embed_tokens
             for _, layer_module in self.model.model.layers.items():
                 if torch.equal(layer_module.shared_head.head.weight,
                                main_model.lm_head.weight):
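
Reviewer note: below is a minimal standalone sketch (not part of this patch) of the two guard patterns the diff introduces: creating the shared-weight group at most once per process, and tying the draft model's embedding to the main model's only when pipeline parallelism is disabled. `get_or_create_shared_group` and `maybe_tie_embeddings` are illustrative names, not vllm-ascend APIs.

```python
import torch
import torch.nn as nn

# Module-level singleton, analogous to _SHARED_WEIGHT in parallel_state.py.
_SHARED_GROUP = None


def get_or_create_shared_group(name: str):
    """Create the shared-weight group at most once per process (idempotent),
    so the SFA-CP path and the flashcomm2 oproj path cannot both create it."""
    global _SHARED_GROUP
    if _SHARED_GROUP is None:
        # Stand-in for _create_shared_weight_group(name).
        _SHARED_GROUP = {"group_name": name}
    return _SHARED_GROUP


def maybe_tie_embeddings(draft: nn.Embedding, main: nn.Embedding,
                         pp_world_size: int) -> nn.Embedding:
    """Reuse the main model's embedding only when pipeline parallelism is off
    and the weights match; with pp > 1 the two tensors may sit on different
    devices, so the torch.equal comparison is skipped entirely."""
    if pp_world_size == 1 and torch.equal(draft.weight, main.weight):
        return main
    return draft


if __name__ == "__main__":
    main_emb = nn.Embedding(8, 4)
    draft_emb = nn.Embedding(8, 4)
    draft_emb.weight.data.copy_(main_emb.weight.data)

    # pp == 1 and identical weights -> the draft reuses the main embedding.
    assert maybe_tie_embeddings(draft_emb, main_emb, pp_world_size=1) is main_emb
    # pp > 1 -> never tied, regardless of weight equality.
    assert maybe_tie_embeddings(draft_emb, main_emb, pp_world_size=2) is draft_emb

    # The second call returns the group created by the first one.
    g1 = get_or_create_shared_group("CP_shared_weight")
    g2 = get_or_create_shared_group("flashcomm2_o_shared")
    assert g1 is g2 and g1["group_name"] == "CP_shared_weight"
```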