bugfix for mtp with multistream_moe (#3419)

### What this PR does / why we need it?
when infer deepseek mtp layer with multistream_moe, we should pass a
boolean to evaluate this feature and fix bugs when we are in mtp layer

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
zouyida2052
2025-10-15 08:59:58 +08:00
committed by GitHub
parent c2c1db78a7
commit 3642b64afc
5 changed files with 22 additions and 11 deletions

View File

@@ -41,6 +41,7 @@ def test_mtp_torchair_correctness(
"use_cached_graph": False, "use_cached_graph": False,
"graph_batch_sizes": [1, 2, 4], "graph_batch_sizes": [1, 2, 4],
}, },
"multistream_overlap_shared_expert": "True"
}) as ref_llm: }) as ref_llm:
ref_outputs = ref_llm.generate(example_prompts, sampling_config) ref_outputs = ref_llm.generate(example_prompts, sampling_config)
with VllmRunner(model_name, with VllmRunner(model_name,
@@ -60,7 +61,8 @@ def test_mtp_torchair_correctness(
"enabled": True, "enabled": True,
"use_cached_graph": False, "use_cached_graph": False,
"graph_batch_sizes": [1, 2, 4], "graph_batch_sizes": [1, 2, 4],
} },
"multistream_overlap_shared_expert": "True"
}) as spec_llm: }) as spec_llm:
spec_outputs = spec_llm.generate(example_prompts, sampling_config) spec_outputs = spec_llm.generate(example_prompts, sampling_config)

View File

@@ -17,6 +17,9 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
config = PretrainedConfig(vocab_size=1000, config = PretrainedConfig(vocab_size=1000,
hidden_size=768, hidden_size=768,
rms_norm_eps=1e-5) rms_norm_eps=1e-5)
mocker.patch(
'vllm_ascend.torchair.models.torchair_deepseek_mtp.get_tensor_model_parallel_world_size',
return_value=1)
mocker.patch( mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__", "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None) return_value=None)
@@ -56,6 +59,8 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768)) mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768), mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
torch.randn(2, 3, 768)) torch.randn(2, 3, 768))
mtp_layer.enorm.return_value = torch.randn(2, 3, 768)
mtp_layer.hnorm.return_value = torch.randn(2, 3, 768)
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]) input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
positions = torch.tensor([[0, 1, 2], [0, 1, 2]]) positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
@@ -65,7 +70,7 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
output = mtp_layer(input_ids, positions, kv_cache, None, output = mtp_layer(input_ids, positions, kv_cache, None,
previous_hidden_states, inputs_embeds, 0) previous_hidden_states, inputs_embeds, 0)
assert output.shape == (2, 3, 768) assert output.shape == (3, 768)
class TestTorchairDeepSeekMultiTokenPredictor(PytestBase): class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):

View File

@@ -103,8 +103,6 @@ def split_decodes_and_prefills(
return num_reqs, 0, num_tokens, 0 return num_reqs, 0, num_tokens, 0
first_prefill = is_prefill.int().argmax(dim=-1).item() first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[first_prefill:] > decode_threshold)
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill num_decodes = first_prefill
num_prefills = num_reqs - num_decodes num_prefills = num_reqs - num_decodes
num_decode_tokens = query_start_loc[first_prefill].item() num_decode_tokens = query_start_loc[first_prefill].item()

View File

@@ -24,6 +24,7 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -66,6 +67,7 @@ class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
) -> None: ) -> None:
nn.Module.__init__(self) nn.Module.__init__(self)
self.tp_size = get_tensor_model_parallel_world_size()
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, self.eh_proj = nn.Linear(config.hidden_size * 2,
@@ -100,11 +102,15 @@ class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
hidden_states = self.eh_proj( hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(positions=positions, replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
hidden_states=hidden_states,
kv_cache=kv_cache, hidden_states, residual = self.mtp_block(
attn_metadata=attn_metadata, positions=positions,
residual=None) hidden_states=hidden_states,
residual=None,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
replace_allreduce=replace_allreduce)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
return hidden_states return hidden_states

View File

@@ -975,7 +975,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# to save npu memory because they're no longer used. # to save npu memory because they're no longer used.
dispose_tensor(previous_hidden_states) dispose_tensor(previous_hidden_states)
dispose_tensor(previous_residual) dispose_tensor(previous_residual)
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace: if mla_moe_communication and self.layer_idx > self.first_k_dense_replace and self.layer_idx < self.layers:
hidden_states = tensor_model_parallel_all_gather(hidden_states, hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0) dim=0)
@@ -1034,7 +1034,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# The scaling of DeepseekV2MOE output would be done in the forward # The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE # of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
if mla_moe_communication and self.layer_idx == self.layers - 1: if mla_moe_communication and self.layer_idx >= self.layers - 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0) dim=0)
residual = tensor_model_parallel_all_gather(residual, dim=0) residual = tensor_model_parallel_all_gather(residual, dim=0)