[Feat]support sequence parallelism by pass for VL models (#5632)

This commit is contained in:
realliujiaxu
2026-02-27 08:27:41 +08:00
committed by GitHub
parent ed175d6d92
commit 5def28dcd3
22 changed files with 460 additions and 101 deletions

View File

@@ -150,7 +150,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
kv_cache: torch.Tensor | None = None,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
need_gather_q_kv = get_forward_context().sp_enabled
need_gather_q_kv = get_forward_context().flash_comm_v1_enabled
output_shape = hidden_states.shape
# FIXME: This does not seem right; we should make sure the buffer is fixed.
output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)