### What this PR does / why we need it?
This reverts commit bf87606932.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Tested with end-to-end (E2E) vLLM serving with `enable_shared_expert_dp: true` in eager mode, as
before.
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -144,17 +144,8 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
need_gather_q_kv = False
|
||||
if sp_enabled and self.debug_layer_idx < self.layers:
|
||||
need_gather_q_kv = True
|
||||
if not sp_enabled or self.debug_layer_idx < self.layers:
|
||||
output_shape = hidden_states.shape
|
||||
else:
|
||||
# used in deepseek mtp layer
|
||||
output_shape = torch.chunk(hidden_states, self.tp_size,
|
||||
dim=0)[0].shape
|
||||
need_gather_q_kv = get_forward_context().sp_enabled
|
||||
output_shape = hidden_states.shape
|
||||
# FIXME: This does not seem right, should make sure the buffer is fixed
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
|
||||
Reference in New Issue
Block a user