diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py index 19f365e0..b6b96a71 100644 --- a/tests/ut/spec_decode/test_eagle_proposer.py +++ b/tests/ut/spec_decode/test_eagle_proposer.py @@ -156,14 +156,19 @@ class TestEagleProposerLoadModel(TestBase): "layer3": mock_draft_layer3 }] + weight = torch.zeros(0) + mock_model = MagicMock() mock_model.supports_multimodal = False - mock_model.model.embed_tokens = MagicMock() mock_model.lm_head = MagicMock() mock_model.multimodal_cpu_fields = None mock_model.merge_by_field_config = None - mock_get_model.return_value = MagicMock() + mock_model.model.embed_tokens = MagicMock() + mock_model.model.embed_tokens.weight = weight + self.proposer.name = SpecDcodeType.EAGLE + mock_get_model.return_value = MagicMock() + mock_get_model.return_value.model.embed_tokens.weight = weight self.proposer.load_model(mock_model) mock_get_model.assert_called_once() diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index b8b9094b..d5d4afaf 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -177,28 +177,21 @@ class EagleProposer(VllmEagleProposer): raise AttributeError( "Target model does not have 'embed_tokens' or 'embedding' attribute" ) - if self.method == "mtp": - if self.vllm_config.model_config.is_deepseek_mla and \ + # If pp>1, the weights of mtp and the main model's embedding are not on the same device. + # check if mtp model use main model's embedding and LMhead + if hasattr(model, "model") and hasattr(model.model, "embed_tokens") and \ torch.equal(self.model.model.embed_tokens.weight, model.model.embed_tokens.weight): - # If pp>1, the weights of mtp and the main model's embedding are not on the same device. - # check if mtp model use main model's embedding and LMhead - logger.info( - "The MTP head shares the same vocab embedding" \ - " with the target model." - ) - self.model.model.embed_tokens = target_embed_tokens - else: - logger.info( - " The MTP head loaded its own vocab embedding" \ - " weights instead of sharing them with the target model." - ) - else: logger.info( "The EAGLE head shares the same vocab embedding" \ " with the target model." ) self.model.model.embed_tokens = target_embed_tokens + else: + logger.info( + " The EAGLE head loaded its own vocab embedding" \ + " weights instead of sharing them with the target model." + ) else: logger.info( "Since PP > 1 or other reasons the model head loaded its own vocab embedding" \