[Feat] Shared expert dp for deepseek and deepseek_mtp (#3495)
### What this PR does / why we need it? Adds shared-expert data parallelism (DP) for deepseek and deepseek_mtp; it can be combined with sequence parallelism (SP) to improve performance. ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: zhaozx-cn <zhaozx2116@163.com> Co-authored-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -88,6 +88,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
inputs_embeds = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
inputs_embeds, True)
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
|
||||
torch.zeros_like(inputs_embeds),
|
||||
@@ -200,4 +202,6 @@ class CustomDeepSeekMTP(DeepSeekMTP):
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, previous_hidden_states,
|
||||
inputs_embeds, spec_step_idx)
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states, True)
|
||||
return hidden_states
|
||||
|
||||
@@ -122,8 +122,17 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
need_gather_q_kv = get_forward_context().sp_enabled
|
||||
output_shape = hidden_states.shape
|
||||
forward_context = get_forward_context()
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
need_gather_q_kv = False
|
||||
if sp_enabled and self.debug_layer_idx < self.layers:
|
||||
need_gather_q_kv = True
|
||||
if not sp_enabled or self.debug_layer_idx < self.layers:
|
||||
output_shape = hidden_states.shape
|
||||
else:
|
||||
# used in deepseek mtp layer
|
||||
output_shape = torch.chunk(hidden_states, self.tp_size,
|
||||
dim=0)[0].shape
|
||||
# FIXME: This does not seem right, should make sure the buffer is fixed
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
|
||||
Reference in New Issue
Block a user