diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 789dd091..fb2928e6 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -226,19 +226,49 @@ class EagleProposer(VllmEagleProposer):
             )
             # If pp>1, the weights of mtp and the main model's embedding are not on the same device.
             # check if mtp model use main model's embedding and LMhead
-            if hasattr(model, "model") and hasattr(model.model, "embed_tokens") and \
-                    torch.equal(self.model.model.embed_tokens.weight,
-                                model.model.embed_tokens.weight):
-                logger.info(
-                    "The EAGLE head shares the same vocab embedding" \
-                    " with the target model."
-                )
-                self.model.model.embed_tokens = target_embed_tokens
+            share_embeddings = False
+            if hasattr(self.model, "has_own_embed_tokens"):
+                # EAGLE model
+                if not self.model.has_own_embed_tokens:
+                    share_embeddings = True
+                    logger.info(
+                        "Detected EAGLE model without its own embed_tokens in the"
+                        " checkpoint. Sharing target model embedding weights with the"
+                        " draft model."
+                    )
+                elif (
+                    isinstance(target_embed_tokens.weight, torch.Tensor)
+                    and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
+                    # TODO: Offload to CPU for comparison to avoid extra NPU memory
+                    # usage in CI testing environments with limited NPU memory
+                    and torch.equal(
+                        target_embed_tokens.weight.cpu(),
+                        self.model.model.embed_tokens.weight.cpu(),
+                    )
+                ):
+                    share_embeddings = True
+                    logger.info(
+                        "Detected EAGLE model with embed_tokens identical to the target"
+                        " model. Sharing target model embedding weights with the draft"
+                        " model."
+                    )
+                else:
+                    logger.info(
+                        "Detected EAGLE model with distinct embed_tokens weights. "
+                        "Keeping separate embedding weights from the target model."
+                    )
             else:
+                # MTP model
+                share_embeddings = True
                 logger.info(
-                    " The EAGLE head loaded its own vocab embedding" \
-                    " weights instead of sharing them with the target model."
+                    "Detected MTP model. "
+                    "Sharing target model embedding weights with the draft model."
                 )
+
+            if share_embeddings:
+                if hasattr(self.model.model, "embed_tokens"):
+                    del self.model.model.embed_tokens
+                self.model.model.embed_tokens = target_embed_tokens
         else:
             logger.info(
                 "Since PP > 1 or other reasons the model head loaded its own vocab embedding" \