bugfix for mtp with multistream_moe (#3419)
### What this PR does / why we need it? when infer deepseek mtp layer with multistream_moe, we should pass a boolean to evaluate this feature and fix bugs when we are in mtp layer - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
@@ -41,6 +41,7 @@ def test_mtp_torchair_correctness(
|
|||||||
"use_cached_graph": False,
|
"use_cached_graph": False,
|
||||||
"graph_batch_sizes": [1, 2, 4],
|
"graph_batch_sizes": [1, 2, 4],
|
||||||
},
|
},
|
||||||
|
"multistream_overlap_shared_expert": "True"
|
||||||
}) as ref_llm:
|
}) as ref_llm:
|
||||||
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
||||||
with VllmRunner(model_name,
|
with VllmRunner(model_name,
|
||||||
@@ -60,7 +61,8 @@ def test_mtp_torchair_correctness(
|
|||||||
"enabled": True,
|
"enabled": True,
|
||||||
"use_cached_graph": False,
|
"use_cached_graph": False,
|
||||||
"graph_batch_sizes": [1, 2, 4],
|
"graph_batch_sizes": [1, 2, 4],
|
||||||
}
|
},
|
||||||
|
"multistream_overlap_shared_expert": "True"
|
||||||
}) as spec_llm:
|
}) as spec_llm:
|
||||||
spec_outputs = spec_llm.generate(example_prompts, sampling_config)
|
spec_outputs = spec_llm.generate(example_prompts, sampling_config)
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
|
|||||||
config = PretrainedConfig(vocab_size=1000,
|
config = PretrainedConfig(vocab_size=1000,
|
||||||
hidden_size=768,
|
hidden_size=768,
|
||||||
rms_norm_eps=1e-5)
|
rms_norm_eps=1e-5)
|
||||||
|
mocker.patch(
|
||||||
|
'vllm_ascend.torchair.models.torchair_deepseek_mtp.get_tensor_model_parallel_world_size',
|
||||||
|
return_value=1)
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
|
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
|
||||||
return_value=None)
|
return_value=None)
|
||||||
@@ -56,6 +59,8 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
|
|||||||
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
|
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
|
||||||
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
|
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
|
||||||
torch.randn(2, 3, 768))
|
torch.randn(2, 3, 768))
|
||||||
|
mtp_layer.enorm.return_value = torch.randn(2, 3, 768)
|
||||||
|
mtp_layer.hnorm.return_value = torch.randn(2, 3, 768)
|
||||||
|
|
||||||
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
|
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
|
||||||
positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
|
positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
|
||||||
@@ -65,7 +70,7 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
|
|||||||
|
|
||||||
output = mtp_layer(input_ids, positions, kv_cache, None,
|
output = mtp_layer(input_ids, positions, kv_cache, None,
|
||||||
previous_hidden_states, inputs_embeds, 0)
|
previous_hidden_states, inputs_embeds, 0)
|
||||||
assert output.shape == (2, 3, 768)
|
assert output.shape == (3, 768)
|
||||||
|
|
||||||
|
|
||||||
class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
|
class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
|
||||||
|
|||||||
@@ -103,8 +103,6 @@ def split_decodes_and_prefills(
|
|||||||
return num_reqs, 0, num_tokens, 0
|
return num_reqs, 0, num_tokens, 0
|
||||||
|
|
||||||
first_prefill = is_prefill.int().argmax(dim=-1).item()
|
first_prefill = is_prefill.int().argmax(dim=-1).item()
|
||||||
assert torch.all(query_lens[first_prefill:] > decode_threshold)
|
|
||||||
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
|
|
||||||
num_decodes = first_prefill
|
num_decodes = first_prefill
|
||||||
num_prefills = num_reqs - num_decodes
|
num_prefills = num_reqs - num_decodes
|
||||||
num_decode_tokens = query_start_loc[first_prefill].item()
|
num_decode_tokens = query_start_loc[first_prefill].item()
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import torch.nn as nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
@@ -66,6 +67,7 @@ class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
|
|||||||
) -> None:
|
) -> None:
|
||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
|
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||||
@@ -100,11 +102,15 @@ class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
|
|||||||
hidden_states = self.eh_proj(
|
hidden_states = self.eh_proj(
|
||||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||||
|
|
||||||
hidden_states, residual = self.mtp_block(positions=positions,
|
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
|
||||||
hidden_states=hidden_states,
|
|
||||||
kv_cache=kv_cache,
|
hidden_states, residual = self.mtp_block(
|
||||||
attn_metadata=attn_metadata,
|
positions=positions,
|
||||||
residual=None)
|
hidden_states=hidden_states,
|
||||||
|
residual=None,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
replace_allreduce=replace_allreduce)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|||||||
@@ -975,7 +975,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
# to save npu memory because they're no longer used.
|
# to save npu memory because they're no longer used.
|
||||||
dispose_tensor(previous_hidden_states)
|
dispose_tensor(previous_hidden_states)
|
||||||
dispose_tensor(previous_residual)
|
dispose_tensor(previous_residual)
|
||||||
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
|
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace and self.layer_idx < self.layers:
|
||||||
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
||||||
dim=0)
|
dim=0)
|
||||||
|
|
||||||
@@ -1034,7 +1034,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
# The scaling of DeepseekV2MOE output would be done in the forward
|
# The scaling of DeepseekV2MOE output would be done in the forward
|
||||||
# of DeepseekV2MOE
|
# of DeepseekV2MOE
|
||||||
hidden_states *= 1. / self.routed_scaling_factor
|
hidden_states *= 1. / self.routed_scaling_factor
|
||||||
if mla_moe_communication and self.layer_idx == self.layers - 1:
|
if mla_moe_communication and self.layer_idx >= self.layers - 1:
|
||||||
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
||||||
dim=0)
|
dim=0)
|
||||||
residual = tensor_model_parallel_all_gather(residual, dim=0)
|
residual = tensor_model_parallel_all_gather(residual, dim=0)
|
||||||
|
|||||||
Reference in New Issue
Block a user