[Feat] Flash comm allgher ep (#3334)

Support flash comm v1(Sequence Parallelism) for Allgather EP.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Co-authored-by: zhaozx-cn <zhaozx2116@163.com>
This commit is contained in:
realliujiaxu
2025-10-15 19:36:32 +08:00
committed by GitHub
parent 8abe517870
commit f69a83b7ba
15 changed files with 283 additions and 78 deletions

View File

@@ -166,7 +166,7 @@ def test_sp_for_qwen3_moe() -> None:
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
example_prompts = [
"Hello, my name is",

View File

@@ -500,9 +500,12 @@ class TestAscendMLAImpl(TestBase):
mock_up_proj.assert_called_once()
mock_npu_fused_infer_attention_score.assert_called_once()
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
def test_mla_preprocess(self, magic_npu_fetch):
def test_mla_preprocess(self, magic_npu_fetch,
mock_maybe_all_gather_and_maybe_unpad):
magic_npu_fetch.return_value = MagicMock()
mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
batch_size = 4
seq_len = 8
hidden_size = 1024

View File

@@ -42,9 +42,11 @@ def test_row_parallel_linear(cls, mock_distributed):
assert output[0].shape == (2, 4, 64)
@patch("vllm_ascend.models.layers.mla.get_forward_context")
@patch("torch.ops.vllm.mla_forward")
@patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
mock_forward_context,
mock_distributed, base_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
# Make a fake ascend config because of the AscendLinearBase
@@ -54,6 +56,9 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
vllm_config.parallel_config.tensor_parallel_size = 1
vllm_config.kv_transfer_config = None
ascend_config.init_ascend_config(vllm_config)
dummy_forward_context = MagicMock()
dummy_forward_context.sp_enabled = False
mock_forward_context.return_value = dummy_forward_context
attn = CustomDeepseekV2MLAAttention(config=base_config,
hidden_size=128,