[Feat] Flash comm allgather EP (#3334)
Support flash comm v1 (Sequence Parallelism) for Allgather EP.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Co-authored-by: zhaozx-cn <zhaozx2116@163.com>
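For orientation: flash comm v1 shards activations along the sequence dimension across parallel ranks and re-materializes the full sequence only where an op needs it. Below is a minimal sketch of the gather-and-unpad step suggested by the maybe_all_gather_and_maybe_unpad op used in the tests; the helper name and layout assumptions (initialized torch.distributed process group, sequence on dim 0, shards pre-padded so the total length divides by world size) are illustrative, not the vllm-ascend implementation.

import torch
import torch.distributed as dist

def all_gather_and_unpad(local_shard: torch.Tensor,
                         orig_len: int) -> torch.Tensor:
    # Each rank holds a slice of the sequence dimension, padded so the
    # total length divides evenly by the world size.
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(gathered, local_shard)
    full = torch.cat(gathered, dim=0)
    # Drop the padding rows appended before the split.
    return full[:orig_len]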
@@ -166,7 +166,7 @@ def test_sp_for_qwen3_moe() -> None:
 @pytest.mark.parametrize("enforce_eager", [True, False])
 @pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
-@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"})
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
 def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
     example_prompts = [
         "Hello, my name is",
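Per the rename above, flash comm v1 is now switched on via the VLLM_ASCEND_ENABLE_FLASHCOMM1 flag. Outside the test suite that would look like the snippet below; the flag name is taken from the hunk above, and setting it before engine initialization is an assumption, since such flags are typically read once at startup.

import os

# Opt in to flash comm v1 (Sequence Parallelism for Allgather EP).
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1"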
@@ -500,9 +500,12 @@ class TestAscendMLAImpl(TestBase):
         mock_up_proj.assert_called_once()
         mock_npu_fused_infer_attention_score.assert_called_once()
 
+    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
     @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
-    def test_mla_preprocess(self, magic_npu_fetch):
+    def test_mla_preprocess(self, magic_npu_fetch,
+                            mock_maybe_all_gather_and_maybe_unpad):
         magic_npu_fetch.return_value = MagicMock()
+        mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
         batch_size = 4
         seq_len = 8
         hidden_size = 1024
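Note the parameter order in the new signature: unittest.mock applies stacked @patch decorators bottom-up, so the innermost patch (maybe_npu_prefetch) binds to the first mock parameter and the newly added outermost patch binds to the second. The passthrough side_effect (lambda x, label: x) makes the op an identity, so tensor shapes flow through the test without a real all-gather. A standalone illustration of the ordering, using two arbitrary stdlib targets as stand-ins:

from unittest.mock import patch

@patch("os.path.exists")  # outermost: injected second
@patch("os.getcwd")       # innermost: injected first
def check(mock_getcwd, mock_exists):
    # Mock arguments arrive in bottom-up decorator order.
    mock_getcwd.return_value = "/tmp"
    mock_exists.return_value = True
    return mock_getcwd(), mock_exists("/tmp")

print(check())  # ('/tmp', True)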
@@ -42,9 +42,11 @@ def test_row_parallel_linear(cls, mock_distributed):
     assert output[0].shape == (2, 4, 64)
 
 
+@patch("vllm_ascend.models.layers.mla.get_forward_context")
 @patch("torch.ops.vllm.mla_forward")
 @patch("torch_npu.npu_rms_norm")
 def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
+                                          mock_forward_context,
                                           mock_distributed, base_config):
     mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
     # Make a fake ascend config because of the AscendLinearBase
@@ -54,6 +56,9 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
     vllm_config.parallel_config.tensor_parallel_size = 1
     vllm_config.kv_transfer_config = None
     ascend_config.init_ascend_config(vllm_config)
+    dummy_forward_context = MagicMock()
+    dummy_forward_context.sp_enabled = False
+    mock_forward_context.return_value = dummy_forward_context
 
     attn = CustomDeepseekV2MLAAttention(config=base_config,
                                         hidden_size=128,
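The stubbed forward context pins sp_enabled to False, keeping the attention layer on the non-sequence-parallel path. The gating this implies looks roughly like the sketch below; the two-argument call shape matches the lambda x, label: x mock above, but the surrounding function is illustrative, not the actual layer code.

import torch

def maybe_gather(hidden_states, forward_context, label=True):
    # Hypothetical gate: only all-gather and unpad when sequence
    # parallelism was enabled for this forward pass.
    if forward_context.sp_enabled:
        return torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            hidden_states, label)
    return hidden_states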