From b2475099a02d55d8a591c99e7c00b8958096ec94 Mon Sep 17 00:00:00 2001
From: drslark <96540755+drslark@users.noreply.github.com>
Date: Tue, 20 Jan 2026 21:34:28 +0800
Subject: [PATCH] [main][Bugfix] Fixed a problem related to embedding sharing
 (#5967)

### What this PR does / why we need it?
Skip embedding sharing when the embedding weights of the main model and the
EAGLE model are different.
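A minimal sketch of the new sharing condition, for readers skimming the
description. The helper name `should_share_embeddings` is illustrative only;
the actual check lives in `EagleProposer.load_model` (see the diff below).

```python
import torch


def should_share_embeddings(draft_model, target_model) -> bool:
    # Share the vocab embedding only when the target model actually exposes
    # `model.embed_tokens` and its weights equal the draft (EAGLE) model's.
    return (
        hasattr(target_model, "model")
        and hasattr(target_model.model, "embed_tokens")
        and torch.equal(
            draft_model.model.embed_tokens.weight,
            target_model.model.embed_tokens.weight,
        )
    )
```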
### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Because I don't have `Meta-Llama-3.1-8B-Instruct` locally, I commented it out
and ran:

```shell
pytest -s tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py::test_llama_qwen_eagle_acceptance
```

The output is fine:

```text
.
=========================================== warnings summary ===========================================
:241
  :241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

:241
  :241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================= 3 passed, 1 skipped, 2 warnings in 196.19s (0:03:16) =========================
```

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2c24bc6996cb165fce92f780b388a5e39b3f4060

Signed-off-by: drslark
---
 tests/ut/spec_decode/test_eagle_proposer.py |  9 ++++++--
 vllm_ascend/spec_decode/eagle_proposer.py   | 23 +++++++--------------
 2 files changed, 15 insertions(+), 17 deletions(-)

diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py
index 19f365e0..b6b96a71 100644
--- a/tests/ut/spec_decode/test_eagle_proposer.py
+++ b/tests/ut/spec_decode/test_eagle_proposer.py
@@ -156,14 +156,19 @@ class TestEagleProposerLoadModel(TestBase):
             "layer3": mock_draft_layer3
         }]
 
+        weight = torch.zeros(0)
+
         mock_model = MagicMock()
         mock_model.supports_multimodal = False
-        mock_model.model.embed_tokens = MagicMock()
         mock_model.lm_head = MagicMock()
         mock_model.multimodal_cpu_fields = None
         mock_model.merge_by_field_config = None
-        mock_get_model.return_value = MagicMock()
+        mock_model.model.embed_tokens = MagicMock()
+        mock_model.model.embed_tokens.weight = weight
+
         self.proposer.name = SpecDcodeType.EAGLE
+        mock_get_model.return_value = MagicMock()
+        mock_get_model.return_value.model.embed_tokens.weight = weight
 
         self.proposer.load_model(mock_model)
         mock_get_model.assert_called_once()
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index b8b9094b..d5d4afaf 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -177,28 +177,21 @@ class EagleProposer(VllmEagleProposer):
                 raise AttributeError(
                     "Target model does not have 'embed_tokens' or 'embedding' attribute"
                 )
-            if self.method == "mtp":
-                if self.vllm_config.model_config.is_deepseek_mla and \
+            # If pp>1, the weights of mtp and the main model's embedding are not on the same device.
+            # check if mtp model use main model's embedding and LMhead
+            if hasattr(model, "model") and hasattr(model.model, "embed_tokens") and \
                     torch.equal(self.model.model.embed_tokens.weight,
                                 model.model.embed_tokens.weight):
-                    # If pp>1, the weights of mtp and the main model's embedding are not on the same device.
-                    # check if mtp model use main model's embedding and LMhead
-                    logger.info(
-                        "The MTP head shares the same vocab embedding" \
-                        " with the target model."
-                    )
-                    self.model.model.embed_tokens = target_embed_tokens
-                else:
-                    logger.info(
-                        " The MTP head loaded its own vocab embedding" \
-                        " weights instead of sharing them with the target model."
-                    )
-            else:
                 logger.info(
                     "The EAGLE head shares the same vocab embedding" \
                     " with the target model."
                 )
                 self.model.model.embed_tokens = target_embed_tokens
+            else:
+                logger.info(
+                    " The EAGLE head loaded its own vocab embedding" \
+                    " weights instead of sharing them with the target model."
+                )
         else:
             logger.info(
                 "Since PP > 1 or other reasons the model head loaded its own vocab embedding" \