[BugFix] Fix DeepSeek V3.2 MTP bug. (#3900)
### What this PR does / why we need it?
This PR fixes a DeepSeek V3.2 MTP (multi-token prediction) bug: the draft model's `DeepseekV32IndexerCache` layers also register as `AttentionLayerBase`, so they leaked into `draft_attn_layer_names` and broke the assumption that the draft model has exactly one attention layer. Since there is currently no dedicated attention backend or attention metadata for the V3.2 indexer, the indexer layers are now excluded when collecting the draft attention layer names.
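Conceptually the fix is set arithmetic over layer names. Below is a minimal, runnable sketch of the before/after behavior; the layer names are made up for illustration, and in the real code every set comes from `get_layers_from_vllm_config`:

```python
# Hypothetical layer names standing in for what
# get_layers_from_vllm_config(...) returns in the real code.

# AttentionLayerBase layers registered *before* the draft (MTP) model
# loads. The V3.2 indexer cache also registers as an AttentionLayerBase,
# which is the root of the bug.
target_attn = {"layers.0.attn", "layers.0.indexer"}
target_indexer = {"layers.0.indexer"}

# The same lookups *after* the draft model loads pick up its layers too.
all_attn = target_attn | {"layers.61.attn", "layers.61.indexer"}
all_indexer = target_indexer | {"layers.61.indexer"}

# Before the fix: the draft indexer leaks into the draft attention set,
# tripping the downstream `assert len(draft_attn_layer_names) == 1`.
draft_attn = all_attn - target_attn
assert draft_attn == {"layers.61.attn", "layers.61.indexer"}

# After the fix: draft indexer layers are subtracted out explicitly,
# since there is no dedicated attention backend/metadata for the
# V3.2 indexer yet.
draft_indexer = all_indexer - target_indexer
draft_attn -= draft_indexer
assert draft_attn == {"layers.61.attn"}
```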
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
All existing CI tests should pass.
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
```diff
@@ -13,6 +13,7 @@ from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import \
     process_weights_after_loading
 from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
+from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -139,6 +140,9 @@ class MtpProposer(Proposer):
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config,
                                         AttentionLayerBase).keys())
+        target_indexer_layer_names = set(
+            get_layers_from_vllm_config(self.vllm_config,
+                                        DeepseekV32IndexerCache).keys())
         draft_model_config = \
             self.vllm_config.speculative_config.draft_model_config
         target_device = self.vllm_config.device_config.device
@@ -152,6 +156,13 @@ class MtpProposer(Proposer):
         draft_attn_layer_names = (get_layers_from_vllm_config(
             self.vllm_config, AttentionLayerBase).keys() -
                                   target_attn_layer_names)
+        indexer_layers = get_layers_from_vllm_config(self.vllm_config,
+                                                     DeepseekV32IndexerCache)
+        draft_indexer_layer_names = indexer_layers.keys(
+        ) - target_indexer_layer_names
+        # NOTE: Currently we don't have specific attention backend and attention metadata
+        # for deepseek v3.2 indexer, so we just exclude the indexer layers here.
+        draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
 
         assert len(draft_attn_layer_names) == 1
         self.attn_layer_name = list(draft_attn_layer_names)
```
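The lookup the fix relies on is type-based filtering. The toy sketch below (not vLLM's implementation; class and layer names are simplified stand-ins) shows why the indexer appears in the `AttentionLayerBase` result at all: `DeepseekV32IndexerCache` evidently counts as an `AttentionLayerBase`, otherwise subtracting the indexer names would be a no-op.

```python
class AttentionLayerBase: ...

class Attention(AttentionLayerBase): ...

class DeepseekV32IndexerCache(AttentionLayerBase): ...

registered = {
    "layers.61.attn": Attention(),
    "layers.61.indexer": DeepseekV32IndexerCache(),
}

def get_layers(layers, layer_type):
    # Mirrors the type filter of get_layers_from_vllm_config.
    return {name: layer for name, layer in layers.items()
            if isinstance(layer, layer_type)}

# The indexer shows up in the AttentionLayerBase lookup...
assert set(get_layers(registered, AttentionLayerBase)) == {
    "layers.61.attn", "layers.61.indexer"}
# ...but can be identified precisely via its own type and subtracted out.
assert set(get_layers(registered, DeepseekV32IndexerCache)) == {
    "layers.61.indexer"}
```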