### What this PR does / why we need it?
In PR https://github.com/vllm-project/vllm-ascend/pull/3420, we
initially placed the quantization type (quant_type) in the MoECommMethod
class. However, since MoECommMethod follows a singleton pattern, it
couldn't accommodate scenarios where different layers in the model might
use different quantization approaches (e.g., MTP modules using
floating-point computation while the main model employs quantized
computation).
In this PR, we've moved the quantization type to the AscendFusedMoe
class and pass it as a parameter to MoECommMethod.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
```bash
export HCCL_BUFFSIZE=1024
export VLLM_VERSION=0.11.0
vllm serve /home/data/DeepSeek-R1_w8a8/ \
--data-parallel-size 2 \
--tensor-parallel-size 8 \
--enable-expert-parallel \
--served-model-name dsv3 \
--max-model-len 32768 \
--max-num-batched-tokens 4096 \
--max-num-seqs 16 \
--quantization ascend \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
```
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
246 lines
10 KiB
Python
246 lines
10 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
import torch
|
|
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
|
AlltoAllCommImpl,
|
|
MC2CommImpl)
|
|
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
|
|
|
|
|
|
class TestMoECommMethod(TestBase):
    """Unit tests for the MoE communication method implementations.

    Every test patches the prepare/finalize class and the token dispatcher
    class inside ``vllm_ascend.ops.fused_moe.moe_comm_method`` and then
    checks that the comm implementation forwards its calls to those
    mocked collaborators with the expected arguments.
    """

    def setUp(self):
        """Build the mocked ``FusedMoEConfig`` shared by all tests."""
        moe_config = MagicMock(spec=FusedMoEConfig)
        moe_config.num_experts = 8
        moe_config.num_local_experts = 2
        moe_config.experts_per_token = 2
        moe_config.tp_group = MagicMock()
        moe_config.tp_group.device_group = MagicMock()
        moe_config.dp_size = 1
        moe_config.tp_size = 1
        moe_config.ep_size = 1
        moe_config.dp_group = MagicMock()
        moe_config.num_global_redundant_experts = 0
        self.moe_config = moe_config

    @staticmethod
    def _stub_environment(mock_get_current_vllm_config,
                          mock_get_forward_context, comm_method_name):
        """Stub the vLLM config and the forward context lookups.

        Args:
            mock_get_current_vllm_config: Patched ``get_current_vllm_config``.
            mock_get_forward_context: Patched ``get_forward_context``.
            comm_method_name: Value exposed as ``moe_comm_method`` on the
                mocked forward context.
        """
        mock_get_current_vllm_config.return_value = MagicMock()

        forward_context = MagicMock()
        forward_context.moe_comm_method = comm_method_name
        mock_get_forward_context.return_value = forward_context

    @staticmethod
    def _stub_prepare_finalize(mock_prepare_finalize, mc2_mask=None):
        """Install a prepare/finalize mock instance and return it.

        ``prepare`` always yields the 4-tuple
        ``(hidden_states, router_logits, mc2_mask, context_metadata)`` that
        the tests unpack, so every test stubs the same return arity.

        Args:
            mock_prepare_finalize: Patched prepare/finalize class.
            mc2_mask: Third element of the stubbed ``prepare`` result.

        Returns:
            The ``MagicMock`` produced when the patched class is called.
        """
        pf_instance = MagicMock()
        pf_instance.prepare.return_value = (torch.randn(4, 8),
                                            torch.randn(4, 2), mc2_mask,
                                            None)
        pf_instance.finalize.return_value = torch.randn(4, 8)
        mock_prepare_finalize.return_value = pf_instance
        return pf_instance

    @staticmethod
    def _stub_token_dispatcher(mock_token_dispatcher):
        """Install a token dispatcher mock instance and return it."""
        td_instance = MagicMock()
        mock_token_dispatcher.return_value = td_instance
        return td_instance

    @staticmethod
    def _check_prepare(comm_impl, pf_instance):
        """Call ``prepare`` and verify the delegated call arguments.

        Returns:
            ``(h_out, context_metadata)`` taken from the ``prepare`` result,
            for use in a follow-up ``finalize`` check.
        """
        hidden_states = torch.randn(3, 8)
        router_logits = torch.randn(3, 2)
        h_out, _, _, context_metadata = comm_impl.prepare(
            hidden_states, router_logits)

        pf_instance.prepare.assert_called_once_with(hidden_states,
                                                    router_logits, False,
                                                    False, QuantType.NONE)
        return h_out, context_metadata

    @staticmethod
    def _check_finalize(comm_impl, pf_instance, h_out, context_metadata):
        """Call ``finalize`` and verify the delegated call arguments."""
        comm_impl.finalize(h_out,
                           reduce_results=True,
                           context_metadata=context_metadata)
        # The stubbed prepare() reports ``None`` as context metadata.
        pf_instance.finalize.assert_called_once_with(h_out, True, None)

    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
    )
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
    )
    def test_all_gather_comm_impl(self, mock_token_dispatcher,
                                  mock_prepare_finalize,
                                  mock_get_forward_context,
                                  mock_get_current_vllm_config):
        """AllGatherCommImpl forwards prepare/finalize to its helper."""
        self._stub_environment(mock_get_current_vllm_config,
                               mock_get_forward_context, "all_gather")
        pf_instance = self._stub_prepare_finalize(mock_prepare_finalize)
        self._stub_token_dispatcher(mock_token_dispatcher)

        comm_impl = AllGatherCommImpl(self.moe_config)

        h_out, context_metadata = self._check_prepare(comm_impl, pf_instance)
        self._check_finalize(comm_impl, pf_instance, h_out, context_metadata)

    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
    def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
                           mock_get_forward_context,
                           mock_get_current_vllm_config):
        """MC2CommImpl forwards prepare/finalize to its helper."""
        self._stub_environment(mock_get_current_vllm_config,
                               mock_get_forward_context, "mc2")
        pf_instance = self._stub_prepare_finalize(
            mock_prepare_finalize, mc2_mask=torch.tensor([1, 0, 1, 0]))
        self._stub_token_dispatcher(mock_token_dispatcher)

        comm_impl = MC2CommImpl(self.moe_config)

        h_out, context_metadata = self._check_prepare(comm_impl, pf_instance)
        self._check_finalize(comm_impl, pf_instance, h_out, context_metadata)

    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
    )
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAll2AllV"
    )
    def test_alltoall_comm_impl(self, mock_token_dispatcher,
                                mock_prepare_finalize,
                                mock_get_forward_context,
                                mock_get_current_vllm_config):
        """AlltoAllCommImpl forwards prepare to its helper."""
        self._stub_environment(mock_get_current_vllm_config,
                               mock_get_forward_context, "alltoall")
        pf_instance = self._stub_prepare_finalize(mock_prepare_finalize)
        self._stub_token_dispatcher(mock_token_dispatcher)

        comm_impl = AlltoAllCommImpl(self.moe_config)

        self._check_prepare(comm_impl, pf_instance)

    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
    )
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
    )
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
    def test_fused_experts_method(self, mock_unified_apply_mlp,
                                  mock_token_dispatcher, mock_prepare_finalize,
                                  mock_get_forward_context,
                                  mock_get_current_vllm_config):
        """fused_experts runs dispatch, the MLP and combine exactly once."""
        self._stub_environment(mock_get_current_vllm_config,
                               mock_get_forward_context, "all_gather")
        self._stub_prepare_finalize(mock_prepare_finalize)

        td_instance = self._stub_token_dispatcher(mock_token_dispatcher)
        td_instance.token_dispatch.return_value = {
            "hidden_states": torch.randn(6, 8),
            "group_list": torch.tensor([2, 2, 2]),
            "group_list_type": 1
        }
        td_instance.token_combine.return_value = torch.randn(4, 8)

        mock_unified_apply_mlp.return_value = torch.randn(6, 8)

        comm_impl = AllGatherCommImpl(self.moe_config)

        # Freshly created tensors are already contiguous, so no explicit
        # ``.contiguous()`` call is needed before handing them over.
        hidden_states = torch.randn(4, 8)
        w1 = torch.randn(16, 8)
        w2 = torch.randn(16, 8)
        topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
                                     [0.6, 0.4]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])

        result = comm_impl.fused_experts(hidden_states=hidden_states,
                                         w1=w1,
                                         w2=w2,
                                         topk_weights=topk_weights,
                                         topk_ids=topk_ids,
                                         activation="silu")

        # The result is expected to have the shape of the stubbed
        # token_combine() output.
        self.assertEqual(result.shape, (4, 8))

        # Each pipeline stage must have been invoked exactly once.
        td_instance.token_dispatch.assert_called_once()
        mock_unified_apply_mlp.assert_called_once()
        td_instance.token_combine.assert_called_once()
|