[Feat] support sequence parallelism by pass for VL models (#5632)
This commit is contained in:
@@ -150,7 +150,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
kv_cache: torch.Tensor | None = None,
|
||||
attn_metadata: AttentionMetadata | None = None,
|
||||
) -> torch.Tensor:
|
||||
need_gather_q_kv = get_forward_context().sp_enabled
|
||||
need_gather_q_kv = get_forward_context().flash_comm_v1_enabled
|
||||
output_shape = hidden_states.shape
|
||||
# FIXME: This does not seem right; we should make sure the buffer is fixed
|
||||
output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
Reference in New Issue
Block a user