[main][Bugfix] Fixed a problem related to embeddings sharing (#5967)

### What this PR does / why we need it?

Disable embedding sharing when the embeddings of the main model and the embeddings of the eagle model are different.
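
In effect, the proposer only aliases the target model's vocab embedding into the eagle head when both embeddings exist and their weights compare equal; otherwise the eagle head keeps the embedding it loaded itself. A minimal sketch of that behaviour (the helper name `share_embed_tokens` and the `draft_model`/`target_model` parameters are illustrative, not the actual names used in the proposer):

```python
import torch


def share_embed_tokens(draft_model, target_model) -> bool:
    """Alias the target model's vocab embedding into the draft (eagle) head
    only when both modules exist and their weights are identical."""
    target = getattr(getattr(target_model, "model", None), "embed_tokens", None)
    draft = getattr(getattr(draft_model, "model", None), "embed_tokens", None)
    if target is None or draft is None:
        return False
    if not torch.equal(draft.weight, target.weight):
        # Different weights: cancel sharing, keep the eagle head's own embedding.
        return False
    # Identical weights: reuse the target module so both models share one embedding.
    draft_model.model.embed_tokens = target
    return True
```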

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

Since I don't have `Meta-Llama-3.1-8B-Instruct` locally, I commented it out and ran:

```shell
pytest -s tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py::test_llama_qwen_eagle_acceptance
```

The output is fine:

```text
.

======================================================================================================================== warnings summary =========================================================================================================================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================================================================== 3 passed, 1 skipped, 2 warnings in 196.19s (0:03:16) =======================================================================================================

```

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

Signed-off-by: drslark <slarksblood@qq.com>
Author: drslark
Date: 2026-01-20 21:34:28 +08:00 (committed by GitHub)
Commit: b2475099a0 (parent: 6c30f8bf87)
2 changed files with 15 additions and 17 deletions

```diff
@@ -156,14 +156,19 @@ class TestEagleProposerLoadModel(TestBase):
             "layer3": mock_draft_layer3
         }]
+        weight = torch.zeros(0)
         mock_model = MagicMock()
         mock_model.supports_multimodal = False
         mock_model.model.embed_tokens = MagicMock()
         mock_model.lm_head = MagicMock()
         mock_model.multimodal_cpu_fields = None
         mock_model.merge_by_field_config = None
         mock_get_model.return_value = MagicMock()
+        mock_model.model.embed_tokens = MagicMock()
+        mock_model.model.embed_tokens.weight = weight
         self.proposer.name = SpecDcodeType.EAGLE
+        mock_get_model.return_value = MagicMock()
+        mock_get_model.return_value.model.embed_tokens.weight = weight
         self.proposer.load_model(mock_model)
         mock_get_model.assert_called_once()
```
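
The unit test now places the same real (zero-element) tensor on both the mocked target model and the mocked draft model, presumably so the `torch.equal` comparison in `load_model` sees identical weights and the sharing path is still exercised; raw `MagicMock` attributes would not be valid inputs to `torch.equal`. A standalone illustration of that detail (not code from the repository):

```python
import torch
from unittest.mock import MagicMock

weight = torch.zeros(0)
target, draft = MagicMock(), MagicMock()
target.model.embed_tokens.weight = weight
draft.model.embed_tokens.weight = weight

# Real tensors with identical contents (here: both empty) compare equal,
# so the sharing branch is taken; torch.equal would reject plain MagicMocks.
assert torch.equal(target.model.embed_tokens.weight,
                   draft.model.embed_tokens.weight)
```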

```diff
@@ -177,28 +177,21 @@ class EagleProposer(VllmEagleProposer):
                 raise AttributeError(
                     "Target model does not have 'embed_tokens' or 'embedding' attribute"
                 )
-            if self.method == "mtp":
-                if self.vllm_config.model_config.is_deepseek_mla and \
-                    # If pp>1, the weights of mtp and the main model's embedding are not on the same device.
-                    # check if mtp model use main model's embedding and LMhead
+            if hasattr(model, "model") and hasattr(model.model, "embed_tokens") and \
+                    torch.equal(self.model.model.embed_tokens.weight,
+                                model.model.embed_tokens.weight):
+                # If pp>1, the weights of mtp and the main model's embedding are not on the same device.
+                # check if mtp model use main model's embedding and LMhead
-                    logger.info(
-                        "The MTP head shares the same vocab embedding" \
-                        " with the target model."
-                    )
-                    self.model.model.embed_tokens = target_embed_tokens
-                else:
-                    logger.info(
-                        " The MTP head loaded its own vocab embedding" \
-                        " weights instead of sharing them with the target model."
-                    )
-            else:
                 logger.info(
                     "The EAGLE head shares the same vocab embedding" \
                     " with the target model."
                 )
                 self.model.model.embed_tokens = target_embed_tokens
+            else:
+                logger.info(
+                    " The EAGLE head loaded its own vocab embedding" \
+                    " weights instead of sharing them with the target model."
+                )
         else:
             logger.info(
                 "Since PP > 1 or other reasons the model head loaded its own vocab embedding" \
```