[Feat] Shared expert dp for deepseek and deepseek_mtp (#3495)

### What this PR does / why we need it?
Adds shared-expert data parallelism (DP) for deepseek and deepseek_mtp. It can be combined with sequence parallelism (SP) to improve performance.

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: zhaozx-cn <zhaozx2116@163.com>
Co-authored-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
zhaozx-cn
2025-10-17 15:06:37 +08:00
committed by GitHub
parent d9ee491f70
commit bf87606932
9 changed files with 57 additions and 10 deletions

```diff
@@ -122,8 +122,17 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
                 hidden_states: torch.Tensor,
                 kv_cache: Optional[torch.Tensor] = None,
                 attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        need_gather_q_kv = get_forward_context().sp_enabled
-        output_shape = hidden_states.shape
+        forward_context = get_forward_context()
+        sp_enabled = forward_context.sp_enabled
+        need_gather_q_kv = False
+        if sp_enabled and self.debug_layer_idx < self.layers:
+            need_gather_q_kv = True
+        if not sp_enabled or self.debug_layer_idx < self.layers:
+            output_shape = hidden_states.shape
+        else:
+            # used in deepseek mtp layer
+            output_shape = torch.chunk(hidden_states, self.tp_size,
+                                       dim=0)[0].shape
         # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape,
                              dtype=hidden_states.dtype,
```