[Feat] shared expert dp for deepseek_mtp (#3811)
### What this PR does / why we need it? Support shared expert DP for the deepseek_mtp feature. `shared_expert_dp` requires `SP==True`, with corresponding parameter restrictions. Previously, due to the coupling between `shared_expert_dp` and torchair, and the removal of `deepseek_mtp` in vllm_ascend, shared expert DP for deepseek_mtp was temporarily removed. Currently, by performing the `reduce_scatter` on the input of deepseek_mtp in `mtp_proposer.py`, we ensure that it matches the dimensions of `input_embedding`, and then perform the `all_gather` on the output of MTP. ### How was this patch tested? baseline: <img width="1184" height="692" alt="image" src="https://github.com/user-attachments/assets/9680d53a-7b1d-481a-accc-b8f3dae2b9e3" /> enable shared_expert_dp and multistream_overlap_shared_expert: <img width="1167" height="687" alt="image" src="https://github.com/user-attachments/assets/2531d06b-dfda-4e24-8628-6f4b0f677ddc" /> TPOT: 48ms -> 45.4ms Average TPS per rank: 117.6 -> 126.1 - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 --------- Signed-off-by: chenmenglong <chenmenglong1@huawei.com> Signed-off-by: zengran <zengran2@huawei.com> Co-authored-by: zengran <zengran2@huawei.com>
This commit is contained in:
@@ -32,7 +32,8 @@ from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||
update_mla_attn_params)
|
||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
|
||||
prefill_context_parallel_enable)
|
||||
prefill_context_parallel_enable,
|
||||
shared_expert_dp_enabled)
|
||||
|
||||
if prefill_context_parallel_enable():
|
||||
from vllm.distributed import get_pcp_group
|
||||
@@ -94,6 +95,7 @@ class MtpProposer(Proposer):
|
||||
# the draft model's hidden size can be different from the target model's
|
||||
# hidden size (e.g., Llama 3.3 70B).
|
||||
self.hidden_size = self.draft_model_config.get_hidden_size()
|
||||
self.enable_shared_expert_dp = shared_expert_dp_enabled()
|
||||
|
||||
self.pcp_size = self.runner.pcp_size
|
||||
self.dcp_size = self.runner.dcp_size
|
||||
@@ -286,6 +288,12 @@ class MtpProposer(Proposer):
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
is_mtp_model=True):
|
||||
if self.enable_shared_expert_dp:
|
||||
positions = positions.unsqueeze(-1)
|
||||
positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
|
||||
positions = positions.squeeze(-1)
|
||||
previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
previous_hidden_states)
|
||||
self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
hidden_states=previous_hidden_states)
|
||||
@@ -294,9 +302,13 @@ class MtpProposer(Proposer):
|
||||
not forward_context.capturing:
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
update_mla_attn_params(
|
||||
self.update_stream, forward_context,
|
||||
positions.shape[0],
|
||||
self.update_stream, forward_context, num_tokens,
|
||||
self.vllm_config.speculative_config)
|
||||
if self.enable_shared_expert_dp:
|
||||
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
positions, True)
|
||||
previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
previous_hidden_states, True)
|
||||
dummy_compute_logits(previous_hidden_states)
|
||||
if with_prefill:
|
||||
break
|
||||
@@ -675,7 +687,8 @@ class MtpProposer(Proposer):
|
||||
|
||||
moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens)
|
||||
|
||||
if scheduler_output:
|
||||
# Enabling shared_expert_dp together with MTP FULL graph may cause accuracy issues.
|
||||
if scheduler_output and not self.enable_shared_expert_dp:
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
uniform_decode = (max_query_len in list(
|
||||
range(1, self.num_speculative_tokens +
|
||||
@@ -725,11 +738,22 @@ class MtpProposer(Proposer):
|
||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||
model_kwargs = {}
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
positions = self.positions[:num_input_tokens]
|
||||
hidden_states = self.hidden_states[:num_input_tokens]
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.hidden_states[:num_input_tokens])
|
||||
if self.enable_shared_expert_dp:
|
||||
# positions [N] -> [N, 1] for padding
|
||||
positions = positions.unsqueeze(-1)
|
||||
positions = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
positions)
|
||||
positions = positions.squeeze(-1)
|
||||
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
hidden_states)
|
||||
|
||||
hidden_states = self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
hidden_states=hidden_states)
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
@@ -738,6 +762,12 @@ class MtpProposer(Proposer):
|
||||
num_input_tokens,
|
||||
self.vllm_config.speculative_config)
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states.contiguous(), True)
|
||||
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
positions.contiguous(), True)
|
||||
|
||||
num_indices = last_token_indices.shape[0]
|
||||
if lmhead_tp_enable():
|
||||
if not self.runner.with_prefill:
|
||||
@@ -805,20 +835,21 @@ class MtpProposer(Proposer):
|
||||
batch_size,
|
||||
attn_metadata_i.decode.actual_seq_lengths_q)
|
||||
attn_metadata_i.decode.cos = builder.cos_cache[
|
||||
positions].unsqueeze(1).unsqueeze(2)
|
||||
positions[:batch_size]].unsqueeze(1).unsqueeze(2)
|
||||
attn_metadata_i.decode.sin = builder.sin_cache[
|
||||
positions].unsqueeze(1).unsqueeze(2)
|
||||
positions[:batch_size]].unsqueeze(1).unsqueeze(2)
|
||||
# NOTE(woosuk): We should handle the case where the draft model
|
||||
# generates tokens beyond the max model length. Since it is complex
|
||||
# to remove such requests from the batch, we keep them in the batch
|
||||
# but adjust the position ids and slot mappings to avoid the
|
||||
# out-of-range access during the model execution. The draft tokens
|
||||
# generated with this adjustment should be ignored.
|
||||
exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
|
||||
exceeds_max_model_len = positions[:
|
||||
batch_size] >= self.runner.model_config.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
||||
positions)
|
||||
positions[:batch_size])
|
||||
# Increment the sequence lengths.
|
||||
attn_metadata_i.seq_lens[:batch_size] += 1
|
||||
# For the requests that exceed the max model length, we set the
|
||||
|
||||
Reference in New Issue
Block a user