[Feat] Flash comm allgher ep (#3334)

Support flash comm v1(Sequence Parallelism) for Allgather EP.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Co-authored-by: zhaozx-cn <zhaozx2116@163.com>
This commit is contained in:
realliujiaxu
2025-10-15 19:36:32 +08:00
committed by GitHub
parent 8abe517870
commit f69a83b7ba
15 changed files with 283 additions and 78 deletions

View File

@@ -42,9 +42,11 @@ def test_row_parallel_linear(cls, mock_distributed):
assert output[0].shape == (2, 4, 64)
@patch("vllm_ascend.models.layers.mla.get_forward_context")
@patch("torch.ops.vllm.mla_forward")
@patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
mock_forward_context,
mock_distributed, base_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
# Make a fake ascend config because of the AscendLinearBase
@@ -54,6 +56,9 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
vllm_config.parallel_config.tensor_parallel_size = 1
vllm_config.kv_transfer_config = None
ascend_config.init_ascend_config(vllm_config)
dummy_forward_context = MagicMock()
dummy_forward_context.sp_enabled = False
mock_forward_context.return_value = dummy_forward_context
attn = CustomDeepseekV2MLAAttention(config=base_config,
hidden_size=128,