diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py
index 1083557..64816ac 100644
--- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py
+++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py
@@ -41,6 +41,7 @@ def test_mtp_torchair_correctness(
                             "use_cached_graph": False,
                             "graph_batch_sizes": [1, 2, 4],
                         },
+                        "multistream_overlap_shared_expert": "True"
                     }) as ref_llm:
         ref_outputs = ref_llm.generate(example_prompts, sampling_config)
     with VllmRunner(model_name,
@@ -60,7 +61,8 @@ def test_mtp_torchair_correctness(
                             "enabled": True,
                             "use_cached_graph": False,
                             "graph_batch_sizes": [1, 2, 4],
-                        }
+                        },
+                        "multistream_overlap_shared_expert": "True"
                     }) as spec_llm:
         spec_outputs = spec_llm.generate(example_prompts, sampling_config)
 
diff --git a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py
index 7feeba9..b9b18ac 100644
--- a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py
+++ b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py
@@ -17,6 +17,9 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
         config = PretrainedConfig(vocab_size=1000,
                                   hidden_size=768,
                                   rms_norm_eps=1e-5)
+        mocker.patch(
+            'vllm_ascend.torchair.models.torchair_deepseek_mtp.get_tensor_model_parallel_world_size',
+            return_value=1)
         mocker.patch(
             "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
             return_value=None)
@@ -56,6 +59,8 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
         mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
         mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
                                             torch.randn(2, 3, 768))
+        mtp_layer.enorm.return_value = torch.randn(2, 3, 768)
+        mtp_layer.hnorm.return_value = torch.randn(2, 3, 768)
 
         input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
         positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
@@ -65,7 +70,7 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
         output = mtp_layer(input_ids, positions, kv_cache, None,
                            previous_hidden_states, inputs_embeds, 0)
 
-        assert output.shape == (2, 3, 768)
+        assert output.shape == (3, 768)
 
 
 class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index cc47122..8dc7efc 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -103,8 +103,6 @@ def split_decodes_and_prefills(
         return num_reqs, 0, num_tokens, 0
 
     first_prefill = is_prefill.int().argmax(dim=-1).item()
-    assert torch.all(query_lens[first_prefill:] > decode_threshold)
-    assert torch.all(query_lens[:first_prefill] <= decode_threshold)
     num_decodes = first_prefill
     num_prefills = num_reqs - num_decodes
     num_decode_tokens = query_start_loc[first_prefill].item()
diff --git a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py
index a7c5a6e..b3760f7 100644
--- a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py
+++ b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py
@@ -24,6 +24,7 @@ import torch.nn as nn
 from transformers import PretrainedConfig
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -66,6 +67,7 @@ class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
     ) -> None:
         nn.Module.__init__(self)
 
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
@@ -100,11 +102,15 @@ class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
         hidden_states = self.eh_proj(
             torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
 
-        hidden_states, residual = self.mtp_block(positions=positions,
-                                                 hidden_states=hidden_states,
-                                                 kv_cache=kv_cache,
-                                                 attn_metadata=attn_metadata,
-                                                 residual=None)
+        replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
+
+        hidden_states, residual = self.mtp_block(
+            positions=positions,
+            hidden_states=hidden_states,
+            residual=None,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+            replace_allreduce=replace_allreduce)
         hidden_states = residual + hidden_states
         return hidden_states
 
diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py
index a7ab345..3b69e25 100644
--- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py
+++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py
@@ -975,7 +975,7 @@
             # to save npu memory because they're no longer used.
             dispose_tensor(previous_hidden_states)
             dispose_tensor(previous_residual)
-        if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
+        if mla_moe_communication and self.layer_idx > self.first_k_dense_replace and self.layer_idx < self.layers:
             hidden_states = tensor_model_parallel_all_gather(hidden_states,
                                                              dim=0)
 
@@ -1034,7 +1034,7 @@
             # The scaling of DeepseekV2MOE output would be done in the forward
             # of DeepseekV2MOE
             hidden_states *= 1. / self.routed_scaling_factor
-        if mla_moe_communication and self.layer_idx == self.layers - 1:
+        if mla_moe_communication and self.layer_idx >= self.layers - 1:
             hidden_states = tensor_model_parallel_all_gather(hidden_states,
                                                              dim=0)
             residual = tensor_model_parallel_all_gather(residual, dim=0)
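
Reviewer note on the vllm_ascend/attention/utils.py hunk: the two removed asserts only re-verified an ordering invariant that is presumably guaranteed upstream, namely that all decode requests (query_len <= decode_threshold) precede all prefills, which is what lets a single argmax over the prefill mask locate the split point. Below is a minimal standalone sketch of that boundary logic in plain PyTorch; it is not the real vllm-ascend function (which takes its inputs from attention metadata), and the explicit query_lens/query_start_loc parameters and the decode_threshold=1 default are assumptions for illustration.

    import torch

    def split_decodes_and_prefills_sketch(query_lens: torch.Tensor,
                                          query_start_loc: torch.Tensor,
                                          decode_threshold: int = 1):
        # Sketch only, not the vllm-ascend signature. query_start_loc holds
        # cumulative token offsets, so its last entry is the total token count.
        num_reqs = query_lens.numel()
        num_tokens = int(query_start_loc[-1].item())
        is_prefill = query_lens > decode_threshold
        if not is_prefill.any():
            # Every request is a decode.
            return num_reqs, 0, num_tokens, 0
        # Decodes are ordered before prefills, so the first True in the mask
        # marks the decode/prefill boundary -- the invariant the removed
        # asserts used to re-check on every call.
        first_prefill = int(is_prefill.int().argmax(dim=-1).item())
        num_decodes = first_prefill
        num_prefills = num_reqs - num_decodes
        num_decode_tokens = int(query_start_loc[first_prefill].item())
        num_prefill_tokens = num_tokens - num_decode_tokens
        return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

    # Three 1-token decodes followed by prefills of 5 and 7 tokens:
    print(split_decodes_and_prefills_sketch(
        torch.tensor([1, 1, 1, 5, 7]),
        torch.tensor([0, 1, 2, 3, 8, 15])))
    # -> (3, 2, 3, 12)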