[MoE] [Refactor] Combine common_fused_moe and fused_moe (#3176)
### What this PR does / why we need it? 1. Move additional functionalities from fused_moe.py to common_fused_moe.py and remove fused_moe.py 2. Remove unnecessary custom classes from qwen3_moe.py, and it will be completely removed after we release vllm-ascend v0.11.0 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing: 1. Enable/Disable EP 3. Aclgraph & eager 4. SP - vLLM version: v0.11.0 --------- Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com> Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -96,7 +96,7 @@ def mock_distributed():
|
||||
patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \
|
||||
patch("vllm_ascend.models.deepseek_v2.get_pp_group",
|
||||
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
|
||||
patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
|
||||
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention
|
||||
|
||||
|
||||
class DummyRMSNorm:
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
denom = (mean_sq + self.eps).sqrt()
|
||||
return x / denom
|
||||
|
||||
|
||||
class TestCustomQwen3MoeAttention(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.batch = 2
|
||||
self.seq_len = 3
|
||||
self.q_size = 8
|
||||
self.kv_size = 8
|
||||
self.head_dim = 4
|
||||
self.rms_eps = 1e-6
|
||||
|
||||
total_dim = self.q_size + 2 * self.kv_size
|
||||
|
||||
self.qkv = torch.arange(self.batch * self.seq_len * total_dim,
|
||||
dtype=torch.float32).reshape(
|
||||
self.batch, self.seq_len, total_dim)
|
||||
|
||||
def test_constant_input_normalization(self):
|
||||
ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size),
|
||||
dtype=torch.float32)
|
||||
|
||||
q_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
|
||||
k_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
|
||||
q, k, v = CustomQwen3MoeAttention.normalize_qkv(
|
||||
ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm, k_norm)
|
||||
|
||||
norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
|
||||
|
||||
expected_q = torch.full((1, 1, self.q_size), norm_val)
|
||||
expected_k = torch.full((1, 1, self.kv_size), norm_val)
|
||||
expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32)
|
||||
|
||||
self.assertTrue(torch.allclose(q, expected_q, atol=1e-6))
|
||||
self.assertTrue(torch.allclose(k, expected_k, atol=1e-6))
|
||||
self.assertTrue(torch.equal(v, expected_v))
|
||||
@@ -21,6 +21,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.moe_config.tp_size = 1
|
||||
self.moe_config.ep_size = 1
|
||||
self.moe_config.dp_group = MagicMock()
|
||||
self.moe_config.original_num_experts = 8
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
@@ -196,7 +197,6 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
rm_router_logits=False,
|
||||
gate=mock_gate)
|
||||
|
||||
# After all-gather with DP=2, should double the batch size
|
||||
@@ -265,7 +265,6 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
# Run prepare
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
rm_router_logits=False,
|
||||
gate=mock_gate)
|
||||
|
||||
# Should be global tensor: [7, 8] and [7, 2]
|
||||
|
||||
@@ -24,8 +24,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch
|
||||
@@ -70,7 +69,7 @@ def setup_vllm_config_mock(mocker: MockerFixture):
|
||||
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
|
||||
mock_vllm_config.model_config.max_model_len = 2048
|
||||
|
||||
mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||
mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
@@ -104,24 +103,24 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
|
||||
with patch('torch.distributed.get_rank', return_value=0), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
||||
return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_ascend_config',
|
||||
return_value=MagicMock(
|
||||
torchair_graph_config=MagicMock(enabled=False),
|
||||
enable_multistream_moe=False,
|
||||
expert_map_path=None
|
||||
)), \
|
||||
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
|
||||
patch('vllm_ascend.ops.common_fused_moe.determine_expert_map',
|
||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_forward_context',
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
@@ -252,196 +251,6 @@ class MockFusedMoEMethod(FusedMoEMethodBase):
|
||||
pass
|
||||
|
||||
|
||||
class TestAscendFusedMoe:
|
||||
|
||||
def test_init_no_quant(self, mock_dist_env, default_moe_config):
|
||||
layer = AscendFusedMoE(**default_moe_config)
|
||||
|
||||
layer.w13_weight = nn.Parameter(
|
||||
torch.randn(default_moe_config['num_experts'],
|
||||
default_moe_config['intermediate_size'] * 2,
|
||||
default_moe_config['hidden_size']))
|
||||
layer.w2_weight = nn.Parameter(
|
||||
torch.randn(default_moe_config['num_experts'],
|
||||
default_moe_config['hidden_size'],
|
||||
default_moe_config['intermediate_size']))
|
||||
|
||||
assert layer.num_experts == default_moe_config['num_experts']
|
||||
assert layer.top_k == default_moe_config['top_k']
|
||||
assert hasattr(layer, 'w13_weight')
|
||||
assert hasattr(layer, 'w2_weight')
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
error_config = default_moe_config.copy()
|
||||
error_config['use_grouped_topk'] = True
|
||||
layer = AscendFusedMoE(**error_config)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
error_config = default_moe_config.copy()
|
||||
error_config['scoring_func'] = "random"
|
||||
layer = AscendFusedMoE(**error_config)
|
||||
|
||||
def test_init_with_quant(self, mock_dist_env, default_moe_config):
|
||||
mock_quant_config = MagicMock()
|
||||
mock_quant_method = MockFusedMoEMethod()
|
||||
mock_quant_config.get_quant_method.return_value = mock_quant_method
|
||||
|
||||
moe = AscendFusedMoE(**default_moe_config,
|
||||
quant_config=mock_quant_config)
|
||||
|
||||
assert moe.quant_method is not None
|
||||
assert moe.quant_method == mock_quant_method
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"others_param",
|
||||
[[None,
|
||||
MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
|
||||
[2, None, False, 5, None], [None, None, True, 5, None],
|
||||
[None, None, False, 1, None], [None, None, True, 5, 1],
|
||||
[None, None, False, 5, 1]])
|
||||
def test_forward(self, mock_dist_env, default_moe_config, others_param):
|
||||
|
||||
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
|
||||
inputs = torch.randn(num_tokens, 32)
|
||||
router_logits = torch.randn(num_tokens, 8)
|
||||
moe = AscendFusedMoE(**default_moe_config)
|
||||
|
||||
if ep_size == 1:
|
||||
moe.moe_parallel_config.ep_size = 1
|
||||
|
||||
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
|
||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
output = moe.forward(inputs,
|
||||
router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=top_k,
|
||||
shared_experts=shared_experts)
|
||||
|
||||
moe.quant_method.apply.assert_called_once()
|
||||
|
||||
if shared_experts:
|
||||
assert output[0].shape == (num_tokens, 32)
|
||||
assert output[1].shape == (num_tokens, 10)
|
||||
else:
|
||||
assert output.shape == (num_tokens, 32)
|
||||
|
||||
def test_forward_ms_fused_moe_comp(self, mock_dist_env,
|
||||
default_moe_config):
|
||||
inputs = torch.randn(5, 32)
|
||||
router_logits = torch.randn(5, 8)
|
||||
moe = AscendFusedMoE(**default_moe_config)
|
||||
|
||||
moe.quant_method = MockQuantMethod(None, 5)
|
||||
output = moe._forward_ms_fused_moe_comp(inputs,
|
||||
router_logits,
|
||||
is_prefill=False,
|
||||
real_top_k=1)
|
||||
|
||||
moe.quant_method.apply.assert_called_once()
|
||||
|
||||
assert output.shape == (5, 32)
|
||||
|
||||
|
||||
class TestAscendUnquantizedFusedMoEMethod:
|
||||
|
||||
def test_process_weights_after_loading(self, moe_method, mock_dist_env):
|
||||
layer = MagicMock()
|
||||
layer.w13_weight.data = torch.randn(16, 32)
|
||||
layer.w2_weight.data = torch.randn(16, 32)
|
||||
|
||||
with patch('torch_npu.npu_format_cast', mock_npu_format_cast), \
|
||||
patch('vllm_ascend.utils.is_310p', return_value=False):
|
||||
moe_method.process_weights_after_loading(layer)
|
||||
|
||||
assert isinstance(layer.w13_weight, torch.nn.Parameter)
|
||||
assert isinstance(layer.w2_weight, torch.nn.Parameter)
|
||||
assert not layer.w13_weight.requires_grad
|
||||
assert not layer.w2_weight.requires_grad
|
||||
|
||||
@pytest.mark.parametrize("others_param",
|
||||
[[256, 4], [128, 1], [128, 1], [128, 4]])
|
||||
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
global_num_experts, ep_size = others_param
|
||||
is_prefill = False
|
||||
|
||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
moe_method.ep_size = ep_size
|
||||
x = torch.randn(8, 2, 2)
|
||||
router_logits = torch.randn(8, 8)
|
||||
layer = MagicMock()
|
||||
local_num_experts = 2
|
||||
hidden_size = 2
|
||||
intermediate_size_per_partition = 4
|
||||
|
||||
layer.w13_weight = torch.randn(local_num_experts,
|
||||
intermediate_size_per_partition * 2,
|
||||
hidden_size)
|
||||
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
|
||||
intermediate_size_per_partition)
|
||||
|
||||
result = moe_method.apply(layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=2,
|
||||
renormalize=True,
|
||||
global_num_experts=global_num_experts,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
|
||||
mock_moe_comm_method.fused_experts.assert_called_once()
|
||||
|
||||
expected_shape = (16, 2)
|
||||
assert result.shape == expected_shape
|
||||
|
||||
@pytest.mark.parametrize("others_param", [16, 1, 4])
|
||||
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
ep_size = others_param
|
||||
is_prefill = False
|
||||
|
||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
|
||||
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
|
||||
moe_method.ep_size = ep_size
|
||||
x = torch.randn(8, 2, 2)
|
||||
if ep_size == 1:
|
||||
x = x.view(-1, 2)
|
||||
router_logits = torch.randn(8, 8)
|
||||
layer = MagicMock()
|
||||
|
||||
local_num_experts = 2
|
||||
hidden_size = 2
|
||||
intermediate_size_per_partition = 4
|
||||
layer.w13_weight = torch.randn(local_num_experts,
|
||||
intermediate_size_per_partition * 2,
|
||||
hidden_size)
|
||||
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
|
||||
intermediate_size_per_partition)
|
||||
|
||||
result = moe_method.apply(layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=2,
|
||||
renormalize=True,
|
||||
global_num_experts=128,
|
||||
expert_map=expert_map,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
|
||||
mock_moe_comm_method.fused_experts.assert_called_once()
|
||||
|
||||
expected_shape = (16, 2)
|
||||
assert result.shape == expected_shape
|
||||
|
||||
|
||||
class TestExpertsSelector:
|
||||
|
||||
@pytest.mark.parametrize("global_num_experts", [[256], [128]])
|
||||
|
||||
@@ -63,7 +63,7 @@ class TestMoECommMethod(TestBase):
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, False, None)
|
||||
hidden_states, router_logits, False, False, None)
|
||||
|
||||
# Test finalize method
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
@@ -108,7 +108,7 @@ class TestMoECommMethod(TestBase):
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, False, None)
|
||||
hidden_states, router_logits, False, False, None)
|
||||
|
||||
# Test finalize method
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
@@ -153,7 +153,7 @@ class TestMoECommMethod(TestBase):
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, False, None)
|
||||
hidden_states, router_logits, False, False, None)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
|
||||
Reference in New Issue
Block a user