[MoE] [Refactor] Combine common_fused_moe and fused_moe (#3176)
### What this PR does / why we need it? 1. Move additional functionalities from fused_moe.py to common_fused_moe.py and remove fused_moe.py 2. Remove unnecessary custom classes from qwen3_moe.py, and it will be completely removed after we release vllm-ascend v0.11.0 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing: 1. Enable/Disable EP 3. Aclgraph & eager 4. SP - vLLM version: v0.11.0 --------- Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com> Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -96,7 +96,7 @@ def mock_distributed():
|
|||||||
patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \
|
patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \
|
||||||
patch("vllm_ascend.models.deepseek_v2.get_pp_group",
|
patch("vllm_ascend.models.deepseek_v2.get_pp_group",
|
||||||
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
|
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
|
||||||
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||||
patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
|
patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
|
||||||
patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
|
patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
|
||||||
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# This file is a part of the vllm-ascend project.
|
|
||||||
#
|
|
||||||
import math
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention
|
|
||||||
|
|
||||||
|
|
||||||
class DummyRMSNorm:
|
|
||||||
|
|
||||||
def __init__(self, dim: int, eps: float = 1e-6):
|
|
||||||
self.dim = dim
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
|
|
||||||
denom = (mean_sq + self.eps).sqrt()
|
|
||||||
return x / denom
|
|
||||||
|
|
||||||
|
|
||||||
class TestCustomQwen3MoeAttention(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.batch = 2
|
|
||||||
self.seq_len = 3
|
|
||||||
self.q_size = 8
|
|
||||||
self.kv_size = 8
|
|
||||||
self.head_dim = 4
|
|
||||||
self.rms_eps = 1e-6
|
|
||||||
|
|
||||||
total_dim = self.q_size + 2 * self.kv_size
|
|
||||||
|
|
||||||
self.qkv = torch.arange(self.batch * self.seq_len * total_dim,
|
|
||||||
dtype=torch.float32).reshape(
|
|
||||||
self.batch, self.seq_len, total_dim)
|
|
||||||
|
|
||||||
def test_constant_input_normalization(self):
|
|
||||||
ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size),
|
|
||||||
dtype=torch.float32)
|
|
||||||
|
|
||||||
q_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
|
|
||||||
k_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
|
|
||||||
q, k, v = CustomQwen3MoeAttention.normalize_qkv(
|
|
||||||
ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm, k_norm)
|
|
||||||
|
|
||||||
norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
|
|
||||||
|
|
||||||
expected_q = torch.full((1, 1, self.q_size), norm_val)
|
|
||||||
expected_k = torch.full((1, 1, self.kv_size), norm_val)
|
|
||||||
expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32)
|
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(q, expected_q, atol=1e-6))
|
|
||||||
self.assertTrue(torch.allclose(k, expected_k, atol=1e-6))
|
|
||||||
self.assertTrue(torch.equal(v, expected_v))
|
|
||||||
@@ -21,6 +21,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
|||||||
self.moe_config.tp_size = 1
|
self.moe_config.tp_size = 1
|
||||||
self.moe_config.ep_size = 1
|
self.moe_config.ep_size = 1
|
||||||
self.moe_config.dp_group = MagicMock()
|
self.moe_config.dp_group = MagicMock()
|
||||||
|
self.moe_config.original_num_experts = 8
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||||
@@ -196,7 +197,6 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
|||||||
|
|
||||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||||
router_logits,
|
router_logits,
|
||||||
rm_router_logits=False,
|
|
||||||
gate=mock_gate)
|
gate=mock_gate)
|
||||||
|
|
||||||
# After all-gather with DP=2, should double the batch size
|
# After all-gather with DP=2, should double the batch size
|
||||||
@@ -265,7 +265,6 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
|||||||
# Run prepare
|
# Run prepare
|
||||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||||
router_logits,
|
router_logits,
|
||||||
rm_router_logits=False,
|
|
||||||
gate=mock_gate)
|
gate=mock_gate)
|
||||||
|
|
||||||
# Should be global tensor: [7, 8] and [7, 2]
|
# Should be global tensor: [7, 8] and [7, 2]
|
||||||
|
|||||||
@@ -24,8 +24,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
|||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.ascend_forward_context import MoECommType
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
|
||||||
AscendUnquantizedFusedMoEMethod)
|
|
||||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||||
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
|
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
|
||||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch
|
from vllm_ascend.utils import AscendSocVersion, adapt_patch
|
||||||
@@ -70,7 +69,7 @@ def setup_vllm_config_mock(mocker: MockerFixture):
|
|||||||
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
|
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
|
||||||
mock_vllm_config.model_config.max_model_len = 2048
|
mock_vllm_config.model_config.max_model_len = 2048
|
||||||
|
|
||||||
mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
|
||||||
return_value=mock_vllm_config)
|
return_value=mock_vllm_config)
|
||||||
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
|
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
|
||||||
return_value=mock_vllm_config)
|
return_value=mock_vllm_config)
|
||||||
@@ -104,24 +103,24 @@ def mock_dist_env(mocker: MockerFixture):
|
|||||||
|
|
||||||
with patch('torch.distributed.get_rank', return_value=0), \
|
with patch('torch.distributed.get_rank', return_value=0), \
|
||||||
patch('torch.distributed.get_world_size', return_value=4), \
|
patch('torch.distributed.get_world_size', return_value=4), \
|
||||||
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||||
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||||
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||||
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||||
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||||
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||||
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
||||||
return_value=mock_dp_and_tp_group(mocker)), \
|
return_value=mock_dp_and_tp_group(mocker)), \
|
||||||
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
|
patch('vllm_ascend.ops.common_fused_moe.get_ascend_config',
|
||||||
return_value=MagicMock(
|
return_value=MagicMock(
|
||||||
torchair_graph_config=MagicMock(enabled=False),
|
torchair_graph_config=MagicMock(enabled=False),
|
||||||
enable_multistream_moe=False,
|
enable_multistream_moe=False,
|
||||||
expert_map_path=None
|
expert_map_path=None
|
||||||
)), \
|
)), \
|
||||||
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
|
patch('vllm_ascend.ops.common_fused_moe.determine_expert_map',
|
||||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||||
patch('vllm_ascend.ops.fused_moe.get_forward_context',
|
patch('vllm_ascend.ops.common_fused_moe.get_forward_context',
|
||||||
return_value=mock_forward_context_obj), \
|
return_value=mock_forward_context_obj), \
|
||||||
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
|
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
|
||||||
return_value=mock_forward_context_obj), \
|
return_value=mock_forward_context_obj), \
|
||||||
@@ -252,196 +251,6 @@ class MockFusedMoEMethod(FusedMoEMethodBase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestAscendFusedMoe:
|
|
||||||
|
|
||||||
def test_init_no_quant(self, mock_dist_env, default_moe_config):
|
|
||||||
layer = AscendFusedMoE(**default_moe_config)
|
|
||||||
|
|
||||||
layer.w13_weight = nn.Parameter(
|
|
||||||
torch.randn(default_moe_config['num_experts'],
|
|
||||||
default_moe_config['intermediate_size'] * 2,
|
|
||||||
default_moe_config['hidden_size']))
|
|
||||||
layer.w2_weight = nn.Parameter(
|
|
||||||
torch.randn(default_moe_config['num_experts'],
|
|
||||||
default_moe_config['hidden_size'],
|
|
||||||
default_moe_config['intermediate_size']))
|
|
||||||
|
|
||||||
assert layer.num_experts == default_moe_config['num_experts']
|
|
||||||
assert layer.top_k == default_moe_config['top_k']
|
|
||||||
assert hasattr(layer, 'w13_weight')
|
|
||||||
assert hasattr(layer, 'w2_weight')
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
error_config = default_moe_config.copy()
|
|
||||||
error_config['use_grouped_topk'] = True
|
|
||||||
layer = AscendFusedMoE(**error_config)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
error_config = default_moe_config.copy()
|
|
||||||
error_config['scoring_func'] = "random"
|
|
||||||
layer = AscendFusedMoE(**error_config)
|
|
||||||
|
|
||||||
def test_init_with_quant(self, mock_dist_env, default_moe_config):
|
|
||||||
mock_quant_config = MagicMock()
|
|
||||||
mock_quant_method = MockFusedMoEMethod()
|
|
||||||
mock_quant_config.get_quant_method.return_value = mock_quant_method
|
|
||||||
|
|
||||||
moe = AscendFusedMoE(**default_moe_config,
|
|
||||||
quant_config=mock_quant_config)
|
|
||||||
|
|
||||||
assert moe.quant_method is not None
|
|
||||||
assert moe.quant_method == mock_quant_method
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"others_param",
|
|
||||||
[[None,
|
|
||||||
MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
|
|
||||||
[2, None, False, 5, None], [None, None, True, 5, None],
|
|
||||||
[None, None, False, 1, None], [None, None, True, 5, 1],
|
|
||||||
[None, None, False, 5, 1]])
|
|
||||||
def test_forward(self, mock_dist_env, default_moe_config, others_param):
|
|
||||||
|
|
||||||
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
|
|
||||||
inputs = torch.randn(num_tokens, 32)
|
|
||||||
router_logits = torch.randn(num_tokens, 8)
|
|
||||||
moe = AscendFusedMoE(**default_moe_config)
|
|
||||||
|
|
||||||
if ep_size == 1:
|
|
||||||
moe.moe_parallel_config.ep_size = 1
|
|
||||||
|
|
||||||
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
|
|
||||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
|
||||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
|
||||||
return_value=forward_context):
|
|
||||||
output = moe.forward(inputs,
|
|
||||||
router_logits,
|
|
||||||
is_prefill=is_prefill,
|
|
||||||
top_k=top_k,
|
|
||||||
shared_experts=shared_experts)
|
|
||||||
|
|
||||||
moe.quant_method.apply.assert_called_once()
|
|
||||||
|
|
||||||
if shared_experts:
|
|
||||||
assert output[0].shape == (num_tokens, 32)
|
|
||||||
assert output[1].shape == (num_tokens, 10)
|
|
||||||
else:
|
|
||||||
assert output.shape == (num_tokens, 32)
|
|
||||||
|
|
||||||
def test_forward_ms_fused_moe_comp(self, mock_dist_env,
|
|
||||||
default_moe_config):
|
|
||||||
inputs = torch.randn(5, 32)
|
|
||||||
router_logits = torch.randn(5, 8)
|
|
||||||
moe = AscendFusedMoE(**default_moe_config)
|
|
||||||
|
|
||||||
moe.quant_method = MockQuantMethod(None, 5)
|
|
||||||
output = moe._forward_ms_fused_moe_comp(inputs,
|
|
||||||
router_logits,
|
|
||||||
is_prefill=False,
|
|
||||||
real_top_k=1)
|
|
||||||
|
|
||||||
moe.quant_method.apply.assert_called_once()
|
|
||||||
|
|
||||||
assert output.shape == (5, 32)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAscendUnquantizedFusedMoEMethod:
|
|
||||||
|
|
||||||
def test_process_weights_after_loading(self, moe_method, mock_dist_env):
|
|
||||||
layer = MagicMock()
|
|
||||||
layer.w13_weight.data = torch.randn(16, 32)
|
|
||||||
layer.w2_weight.data = torch.randn(16, 32)
|
|
||||||
|
|
||||||
with patch('torch_npu.npu_format_cast', mock_npu_format_cast), \
|
|
||||||
patch('vllm_ascend.utils.is_310p', return_value=False):
|
|
||||||
moe_method.process_weights_after_loading(layer)
|
|
||||||
|
|
||||||
assert isinstance(layer.w13_weight, torch.nn.Parameter)
|
|
||||||
assert isinstance(layer.w2_weight, torch.nn.Parameter)
|
|
||||||
assert not layer.w13_weight.requires_grad
|
|
||||||
assert not layer.w2_weight.requires_grad
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("others_param",
|
|
||||||
[[256, 4], [128, 1], [128, 1], [128, 4]])
|
|
||||||
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
|
|
||||||
mock_moe_env, others_param):
|
|
||||||
global_num_experts, ep_size = others_param
|
|
||||||
is_prefill = False
|
|
||||||
|
|
||||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
|
||||||
|
|
||||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
|
||||||
return_value=forward_context):
|
|
||||||
moe_method.ep_size = ep_size
|
|
||||||
x = torch.randn(8, 2, 2)
|
|
||||||
router_logits = torch.randn(8, 8)
|
|
||||||
layer = MagicMock()
|
|
||||||
local_num_experts = 2
|
|
||||||
hidden_size = 2
|
|
||||||
intermediate_size_per_partition = 4
|
|
||||||
|
|
||||||
layer.w13_weight = torch.randn(local_num_experts,
|
|
||||||
intermediate_size_per_partition * 2,
|
|
||||||
hidden_size)
|
|
||||||
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
|
|
||||||
intermediate_size_per_partition)
|
|
||||||
|
|
||||||
result = moe_method.apply(layer=layer,
|
|
||||||
x=x,
|
|
||||||
router_logits=router_logits,
|
|
||||||
top_k=2,
|
|
||||||
renormalize=True,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
is_prefill=is_prefill)
|
|
||||||
|
|
||||||
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
|
|
||||||
mock_moe_comm_method.fused_experts.assert_called_once()
|
|
||||||
|
|
||||||
expected_shape = (16, 2)
|
|
||||||
assert result.shape == expected_shape
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("others_param", [16, 1, 4])
|
|
||||||
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
|
|
||||||
mock_moe_env, others_param):
|
|
||||||
ep_size = others_param
|
|
||||||
is_prefill = False
|
|
||||||
|
|
||||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
|
||||||
|
|
||||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
|
||||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
|
|
||||||
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
|
|
||||||
moe_method.ep_size = ep_size
|
|
||||||
x = torch.randn(8, 2, 2)
|
|
||||||
if ep_size == 1:
|
|
||||||
x = x.view(-1, 2)
|
|
||||||
router_logits = torch.randn(8, 8)
|
|
||||||
layer = MagicMock()
|
|
||||||
|
|
||||||
local_num_experts = 2
|
|
||||||
hidden_size = 2
|
|
||||||
intermediate_size_per_partition = 4
|
|
||||||
layer.w13_weight = torch.randn(local_num_experts,
|
|
||||||
intermediate_size_per_partition * 2,
|
|
||||||
hidden_size)
|
|
||||||
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
|
|
||||||
intermediate_size_per_partition)
|
|
||||||
|
|
||||||
result = moe_method.apply(layer=layer,
|
|
||||||
x=x,
|
|
||||||
router_logits=router_logits,
|
|
||||||
top_k=2,
|
|
||||||
renormalize=True,
|
|
||||||
global_num_experts=128,
|
|
||||||
expert_map=expert_map,
|
|
||||||
is_prefill=is_prefill)
|
|
||||||
|
|
||||||
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
|
|
||||||
mock_moe_comm_method.fused_experts.assert_called_once()
|
|
||||||
|
|
||||||
expected_shape = (16, 2)
|
|
||||||
assert result.shape == expected_shape
|
|
||||||
|
|
||||||
|
|
||||||
class TestExpertsSelector:
|
class TestExpertsSelector:
|
||||||
|
|
||||||
@pytest.mark.parametrize("global_num_experts", [[256], [128]])
|
@pytest.mark.parametrize("global_num_experts", [[256], [128]])
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
|
|
||||||
# Verify prepare was called with correct arguments
|
# Verify prepare was called with correct arguments
|
||||||
mock_pf_instance.prepare.assert_called_once_with(
|
mock_pf_instance.prepare.assert_called_once_with(
|
||||||
hidden_states, router_logits, False, False, False, None)
|
hidden_states, router_logits, False, False, None)
|
||||||
|
|
||||||
# Test finalize method
|
# Test finalize method
|
||||||
comm_impl.finalize(h_out, reduce_results=True)
|
comm_impl.finalize(h_out, reduce_results=True)
|
||||||
@@ -108,7 +108,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
|
|
||||||
# Verify prepare was called with correct arguments
|
# Verify prepare was called with correct arguments
|
||||||
mock_pf_instance.prepare.assert_called_once_with(
|
mock_pf_instance.prepare.assert_called_once_with(
|
||||||
hidden_states, router_logits, False, False, False, None)
|
hidden_states, router_logits, False, False, None)
|
||||||
|
|
||||||
# Test finalize method
|
# Test finalize method
|
||||||
comm_impl.finalize(h_out, reduce_results=True)
|
comm_impl.finalize(h_out, reduce_results=True)
|
||||||
@@ -153,7 +153,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
|
|
||||||
# Verify prepare was called with correct arguments
|
# Verify prepare was called with correct arguments
|
||||||
mock_pf_instance.prepare.assert_called_once_with(
|
mock_pf_instance.prepare.assert_called_once_with(
|
||||||
hidden_states, router_logits, False, False, False, None)
|
hidden_states, router_logits, False, False, None)
|
||||||
|
|
||||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||||
|
|||||||
@@ -45,10 +45,6 @@ def register_model():
|
|||||||
"DeepSeekMTPModel",
|
"DeepSeekMTPModel",
|
||||||
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
|
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
|
||||||
|
|
||||||
ModelRegistry.register_model(
|
|
||||||
"Qwen3MoeForCausalLM",
|
|
||||||
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
|
||||||
|
|
||||||
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
|
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
|
||||||
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
|
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
|
||||||
ModelRegistry.register_model(
|
ModelRegistry.register_model(
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ from vllm_ascend.ascend_config import get_ascend_config
|
|||||||
from vllm_ascend.models.layers.mla import AscendMLAModules
|
from vllm_ascend.models.layers.mla import AscendMLAModules
|
||||||
from vllm_ascend.models.layers.sfa import (AscendSFAModules,
|
from vllm_ascend.models.layers.sfa import (AscendSFAModules,
|
||||||
AscendSparseFlashAttention, Indexer)
|
AscendSparseFlashAttention, Indexer)
|
||||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
|
||||||
|
|
||||||
|
|
||||||
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
|
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
|
||||||
|
|||||||
@@ -1,263 +0,0 @@
|
|||||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
||||||
# Copyright 2024 The Qwen team.
|
|
||||||
# Copyright 2023 The vLLM team.
|
|
||||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# Adapted from vllm/model_executor/models/qwen3_moe.py
|
|
||||||
# This file is a part of the vllm-ascend project.
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from transformers import PretrainedConfig
|
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
|
||||||
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
|
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
||||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
|
||||||
get_tp_group)
|
|
||||||
from vllm.forward_context import get_forward_context
|
|
||||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
||||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
||||||
ParallelLMHead, VocabParallelEmbedding)
|
|
||||||
from vllm.model_executor.models.interfaces import (MixtureOfExperts,
|
|
||||||
SupportsLoRA, SupportsPP)
|
|
||||||
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
|
|
||||||
Qwen3MoeDecoderLayer,
|
|
||||||
Qwen3MoeForCausalLM,
|
|
||||||
Qwen3MoeMLP, Qwen3MoeModel,
|
|
||||||
Qwen3MoeSparseMoeBlock)
|
|
||||||
from vllm.model_executor.models.utils import (
|
|
||||||
PPMissingLayer, extract_layer_index,
|
|
||||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
|
||||||
|
|
||||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
|
||||||
|
|
||||||
|
|
||||||
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: PretrainedConfig,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
prefix: str = "",
|
|
||||||
):
|
|
||||||
nn.Module.__init__(self)
|
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
if self.tp_size > config.num_experts:
|
|
||||||
raise ValueError(
|
|
||||||
f"Tensor parallel size {self.tp_size} is greater than "
|
|
||||||
f"the number of experts {config.num_experts}.")
|
|
||||||
|
|
||||||
self.gate = ReplicatedLinear(
|
|
||||||
config.hidden_size,
|
|
||||||
config.num_experts,
|
|
||||||
bias=False,
|
|
||||||
quant_config=None,
|
|
||||||
prefix=f"{prefix}.gate",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.experts = AscendFusedMoE(
|
|
||||||
num_experts=config.num_experts,
|
|
||||||
top_k=config.num_experts_per_tok,
|
|
||||||
hidden_size=config.hidden_size,
|
|
||||||
intermediate_size=config.moe_intermediate_size,
|
|
||||||
reduce_results=False,
|
|
||||||
renormalize=config.norm_topk_prob,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.experts",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.top_k = config.num_experts_per_tok
|
|
||||||
|
|
||||||
self.dp_size = get_dp_group().world_size
|
|
||||||
|
|
||||||
self.tp_group = get_tp_group().device_group
|
|
||||||
self.tp_rank = get_tp_group().rank_in_group
|
|
||||||
self.ep_group = get_ep_group()
|
|
||||||
|
|
||||||
self.params_dtype = torch.get_default_dtype()
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states,
|
|
||||||
attn_metadata=None,
|
|
||||||
):
|
|
||||||
if attn_metadata is None:
|
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
|
||||||
# when profile runs, force experts to load balanced tokens
|
|
||||||
# to avoid high memory consumption on a single rank.
|
|
||||||
enable_force_load_balance = get_forward_context().in_profile_run
|
|
||||||
is_prefill = get_forward_context().with_prefill
|
|
||||||
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
|
||||||
router_logits, _ = self.gate(hidden_states)
|
|
||||||
|
|
||||||
hidden_states = self.experts(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
router_logits=router_logits,
|
|
||||||
is_prefill=is_prefill,
|
|
||||||
top_k=self.top_k,
|
|
||||||
enable_force_load_balance=enable_force_load_balance,
|
|
||||||
shared_experts=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: PretrainedConfig,
|
|
||||||
cache_config: Optional[CacheConfig] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
vllm_config: Optional[VllmConfig] = None,
|
|
||||||
prefix: str = "",
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
nn.Module.__init__(self)
|
|
||||||
self.hidden_size = config.hidden_size
|
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
|
||||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
|
||||||
8192)
|
|
||||||
self.self_attn = Qwen3MoeAttention(
|
|
||||||
hidden_size=self.hidden_size,
|
|
||||||
num_heads=config.num_attention_heads,
|
|
||||||
num_kv_heads=config.num_key_value_heads,
|
|
||||||
rope_theta=rope_theta,
|
|
||||||
rope_scaling=rope_scaling,
|
|
||||||
max_position_embeddings=max_position_embeddings,
|
|
||||||
rms_norm_eps=config.rms_norm_eps,
|
|
||||||
qkv_bias=getattr(config, 'attention_bias', False),
|
|
||||||
head_dim=getattr(config, 'head_dim', None),
|
|
||||||
cache_config=cache_config,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.self_attn",
|
|
||||||
)
|
|
||||||
|
|
||||||
# `mlp_only_layers` in the config.
|
|
||||||
layer_idx = extract_layer_index(prefix)
|
|
||||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
|
||||||
config.mlp_only_layers)
|
|
||||||
self.use_aclgraph = (vllm_config is not None
|
|
||||||
and vllm_config.compilation_config.level
|
|
||||||
== CompilationLevel.PIECEWISE
|
|
||||||
and not vllm_config.model_config.enforce_eager)
|
|
||||||
if (layer_idx not in mlp_only_layers) and (
|
|
||||||
config.num_experts > 0 and
|
|
||||||
(layer_idx + 1) % config.decoder_sparse_step == 0):
|
|
||||||
if not self.use_aclgraph:
|
|
||||||
# FIXME: custom sparse moe block doesn't work with aclgraph.
|
|
||||||
self.mlp = CustomSparseMoeBlock(config=config,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.mlp")
|
|
||||||
else:
|
|
||||||
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
|
|
||||||
prefix=f"{prefix}.mlp")
|
|
||||||
else:
|
|
||||||
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
|
|
||||||
intermediate_size=config.intermediate_size,
|
|
||||||
hidden_act=config.hidden_act,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.mlp")
|
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
|
||||||
eps=config.rms_norm_eps)
|
|
||||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
|
||||||
eps=config.rms_norm_eps)
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
|
||||||
class CustomQwen3MoeModel(Qwen3MoeModel):
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
||||||
nn.Module.__init__(self)
|
|
||||||
config = vllm_config.model_config.hf_config
|
|
||||||
cache_config = vllm_config.cache_config
|
|
||||||
quant_config = vllm_config.quant_config
|
|
||||||
|
|
||||||
parallel_config = vllm_config.parallel_config
|
|
||||||
eplb_config = parallel_config.eplb_config
|
|
||||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
|
||||||
self.padding_idx = config.pad_token_id
|
|
||||||
self.vocab_size = config.vocab_size
|
|
||||||
self.config = config
|
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
|
||||||
config.vocab_size,
|
|
||||||
config.hidden_size,
|
|
||||||
prefix=f"{prefix}.embed_tokens")
|
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
|
||||||
config.num_hidden_layers,
|
|
||||||
lambda prefix: CustomQwen3MoeDecoderLayer(
|
|
||||||
config=config,
|
|
||||||
cache_config=cache_config,
|
|
||||||
quant_config=quant_config,
|
|
||||||
vllm_config=vllm_config,
|
|
||||||
prefix=prefix),
|
|
||||||
prefix=f"{prefix}.layers",
|
|
||||||
)
|
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
self.make_empty_intermediate_tensors = (
|
|
||||||
make_empty_intermediate_tensors_factory(
|
|
||||||
["hidden_states", "residual"], config.hidden_size))
|
|
||||||
|
|
||||||
|
|
||||||
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
||||||
nn.Module.__init__(self)
|
|
||||||
SupportsPP.__init__(self)
|
|
||||||
SupportsLoRA.__init__(self)
|
|
||||||
MixtureOfExperts.__init__(self)
|
|
||||||
config = vllm_config.model_config.hf_config
|
|
||||||
quant_config = vllm_config.quant_config
|
|
||||||
self.config = config
|
|
||||||
self.quant_config = quant_config
|
|
||||||
self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
|
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
|
||||||
config.hidden_size,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=maybe_prefix(prefix, "lm_head"))
|
|
||||||
if self.config.tie_word_embeddings:
|
|
||||||
self.lm_head.weight = self.model.embed_tokens.weight
|
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
|
||||||
self.make_empty_intermediate_tensors = (
|
|
||||||
self.model.make_empty_intermediate_tensors)
|
|
||||||
|
|
||||||
# Set MoE hyperparameters
|
|
||||||
self.expert_weights: list[torch.Tensor] = []
|
|
||||||
|
|
||||||
self.moe_layers: list[FusedMoE] = []
|
|
||||||
example_layer = None
|
|
||||||
for layer in self.model.layers:
|
|
||||||
if isinstance(layer, PPMissingLayer):
|
|
||||||
continue
|
|
||||||
|
|
||||||
assert isinstance(layer, Qwen3MoeDecoderLayer)
|
|
||||||
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
|
|
||||||
example_layer = layer.mlp
|
|
||||||
self.moe_layers.append(layer.mlp.experts)
|
|
||||||
|
|
||||||
if example_layer is None:
|
|
||||||
raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
|
|
||||||
|
|
||||||
self.num_moe_layers = len(self.moe_layers)
|
|
||||||
self.num_expert_groups = 1
|
|
||||||
self.num_shared_experts = 0
|
|
||||||
@@ -18,7 +18,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm_ascend.ops.common_fused_moe # noqa
|
import vllm_ascend.ops.common_fused_moe # noqa
|
||||||
import vllm_ascend.ops.fused_moe # noqa
|
|
||||||
import vllm_ascend.ops.layernorm # noqa
|
import vllm_ascend.ops.layernorm # noqa
|
||||||
import vllm_ascend.ops.register_custom_ops # noqa
|
import vllm_ascend.ops.register_custom_ops # noqa
|
||||||
import vllm_ascend.ops.vocab_parallel_embedding # noqa
|
import vllm_ascend.ops.vocab_parallel_embedding # noqa
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import os.path
|
import os.path
|
||||||
from typing import Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
@@ -23,6 +23,7 @@ from vllm.config import CompilationLevel, get_current_vllm_config
|
|||||||
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||||
@@ -37,99 +38,110 @@ from vllm_ascend.ops.moe.experts_selector import select_experts
|
|||||||
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
||||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
|
||||||
|
|
||||||
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
|
||||||
|
|
||||||
|
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||||
|
|
||||||
def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
def __init__(self, moe: FusedMoEConfig = None):
|
||||||
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
|
|
||||||
|
|
||||||
# NOTE: Currently, this self.use_aclgraph is only used in
|
super().__init__(moe=moe)
|
||||||
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
|
|
||||||
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
|
|
||||||
# Once torch.randint_like is supported or removed, this flag can be removed.
|
|
||||||
vllm_config = get_current_vllm_config()
|
|
||||||
ascend_config = get_ascend_config()
|
|
||||||
if ascend_config.torchair_graph_config.enabled:
|
|
||||||
self.use_aclgraph = False
|
|
||||||
else:
|
|
||||||
self.use_aclgraph = (vllm_config.compilation_config.level
|
|
||||||
== CompilationLevel.PIECEWISE
|
|
||||||
and not vllm_config.model_config.enforce_eager)
|
|
||||||
self.transpose = True
|
|
||||||
|
|
||||||
|
# NOTE: Currently, this self.use_aclgraph is only used in
|
||||||
|
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
|
||||||
|
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
|
||||||
|
# Once torch.randint_like is supported or removed, this flag can be removed.
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
|
self.dynamic_eplb = get_ascend_config().dynamic_eplb
|
||||||
|
if ascend_config.torchair_graph_config.enabled:
|
||||||
|
self.use_aclgraph = False
|
||||||
|
else:
|
||||||
|
self.use_aclgraph = (vllm_config.compilation_config.level
|
||||||
|
== CompilationLevel.PIECEWISE and
|
||||||
|
not vllm_config.model_config.enforce_eager)
|
||||||
|
self.transpose = True
|
||||||
|
|
||||||
def forward_oot(
|
def process_weights_after_loading(self, layer):
|
||||||
self,
|
super(UnquantizedFusedMoEMethod,
|
||||||
layer: torch.nn.Module,
|
self).process_weights_after_loading(layer)
|
||||||
x: torch.Tensor,
|
if self.transpose:
|
||||||
use_grouped_topk: bool,
|
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
||||||
top_k: int,
|
1, 2).contiguous()
|
||||||
router_logits: torch.Tensor,
|
layer.w13_weight = torch.nn.Parameter(w13_data,
|
||||||
renormalize: bool,
|
requires_grad=False)
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
scoring_func: str = "softmax",
|
|
||||||
routed_scaling_factor: float = 1.0,
|
|
||||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
|
||||||
global_num_experts: int = -1,
|
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
activation: str = "silu",
|
|
||||||
enable_eplb: bool = False,
|
|
||||||
expert_load_view: Optional[torch.Tensor] = None,
|
|
||||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
|
||||||
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
|
|
||||||
topk_weights, topk_ids, row_idx = select_experts(
|
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
|
||||||
hidden_states=x,
|
1, 2).contiguous()
|
||||||
router_logits=router_logits,
|
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||||
top_k=top_k,
|
|
||||||
use_grouped_topk=use_grouped_topk,
|
|
||||||
renormalize=renormalize,
|
|
||||||
topk_group=topk_group,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
custom_routing_function=custom_routing_function,
|
|
||||||
scoring_func=scoring_func,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
e_score_correction_bias=e_score_correction_bias,
|
|
||||||
global_num_experts=global_num_experts)
|
|
||||||
|
|
||||||
moe_comm_method = get_forward_context().moe_comm_method
|
self.transpose = False
|
||||||
return moe_comm_method.fused_experts(hidden_states=x,
|
else:
|
||||||
w1=layer.w13_weight,
|
w13_data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||||
w2=layer.w2_weight,
|
layer.w13_weight = torch.nn.Parameter(w13_data,
|
||||||
topk_weights=topk_weights,
|
requires_grad=False)
|
||||||
topk_ids=topk_ids,
|
|
||||||
row_idx=row_idx,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=expert_map)
|
|
||||||
|
|
||||||
|
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||||
|
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer):
|
if not is_310p():
|
||||||
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
|
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||||
if self.transpose:
|
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||||
1, 2).contiguous()
|
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||||
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
|
||||||
|
|
||||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
|
def apply(self,
|
||||||
1, 2).contiguous()
|
layer: torch.nn.Module,
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
x: torch.Tensor,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
top_k: int,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
renormalize: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
scoring_func: str = "softmax",
|
||||||
|
routed_scaling_factor: float = 1.0,
|
||||||
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
enable_force_load_balance: bool = False,
|
||||||
|
shared_experts: Optional[Any] = None,
|
||||||
|
**kwargs) -> torch.Tensor:
|
||||||
|
|
||||||
self.transpose = False
|
topk_weights, topk_ids, row_idx = select_experts(
|
||||||
else:
|
hidden_states=x,
|
||||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data)
|
router_logits=router_logits,
|
||||||
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
top_k=top_k,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
scoring_func=scoring_func,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
|
global_num_experts=global_num_experts)
|
||||||
|
|
||||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
|
topk_weights = topk_weights.to(x.dtype)
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
# this is a naive implementation for experts load balance so as
|
||||||
|
# to avoid accumulating too much tokens on a single rank.
|
||||||
|
# currently it is only activated when doing profile runs.
|
||||||
|
if enable_force_load_balance and not self.use_aclgraph:
|
||||||
|
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||||
|
|
||||||
if not is_310p():
|
moe_comm_method = get_forward_context().moe_comm_method
|
||||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
return moe_comm_method.fused_experts(
|
||||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
hidden_states=x,
|
||||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
w1=layer.w13_weight,
|
||||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
row_idx=row_idx,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map,
|
||||||
|
shared_experts=shared_experts,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
dynamic_eplb=self.dynamic_eplb)
|
||||||
|
|
||||||
|
|
||||||
class AscendFusedMoE(FusedMoE):
|
class AscendFusedMoE(FusedMoE):
|
||||||
@@ -138,8 +150,26 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
num_experts = kwargs["num_experts"]
|
||||||
|
intermediate_size = kwargs["intermediate_size"]
|
||||||
|
|
||||||
AscendFusedMoE.moe_counter += 1
|
AscendFusedMoE.moe_counter += 1
|
||||||
self.moe_instance_id = AscendFusedMoE.moe_counter
|
self.moe_instance_id = AscendFusedMoE.moe_counter
|
||||||
|
|
||||||
|
self.global_num_experts = num_experts
|
||||||
|
self.expert_map = None
|
||||||
|
self.log2phy = None
|
||||||
|
self.global_redundant_expert_num = 0
|
||||||
|
|
||||||
|
if self.quant_config is None:
|
||||||
|
self.quant_method = AscendUnquantizedFusedMoEMethod(
|
||||||
|
self.moe_config)
|
||||||
|
else:
|
||||||
|
self.quant_method = self.quant_config.get_quant_method(
|
||||||
|
self, self.layer_name)
|
||||||
|
|
||||||
|
assert self.quant_method is not None
|
||||||
|
|
||||||
self.moe_config.tp_group = get_tp_group()
|
self.moe_config.tp_group = get_tp_group()
|
||||||
self.moe_config.dp_group = get_dp_group()
|
self.moe_config.dp_group = get_dp_group()
|
||||||
self.moe_config.ep_group = get_ep_group()
|
self.moe_config.ep_group = get_ep_group()
|
||||||
@@ -148,6 +178,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
self.dynamic_eplb = ascend_config.dynamic_eplb
|
self.dynamic_eplb = ascend_config.dynamic_eplb
|
||||||
self.expert_map_path = ascend_config.expert_map_path
|
self.expert_map_path = ascend_config.expert_map_path
|
||||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
||||||
|
self.global_num_experts = num_experts + self.global_redundant_expert_num
|
||||||
# static eplb initializing with expert_map_path
|
# static eplb initializing with expert_map_path
|
||||||
if self.expert_map_path and os.path.exists(
|
if self.expert_map_path and os.path.exists(
|
||||||
self.expert_map_path) and os.access(self.expert_map_path,
|
self.expert_map_path) and os.access(self.expert_map_path,
|
||||||
@@ -180,6 +211,25 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
if self.dynamic_eplb:
|
if self.dynamic_eplb:
|
||||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
||||||
|
|
||||||
|
self.moe_config.num_experts = self.global_num_experts
|
||||||
|
self.moe_config.num_local_experts = self.local_num_experts
|
||||||
|
self.moe_config.original_num_experts = num_experts
|
||||||
|
|
||||||
|
moe_quant_params = {
|
||||||
|
"num_experts": local_num_experts,
|
||||||
|
"hidden_size": self.hidden_size,
|
||||||
|
"intermediate_size_per_partition":
|
||||||
|
self.intermediate_size_per_partition,
|
||||||
|
"params_dtype": self.params_dtype,
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
}
|
||||||
|
# need full intermediate size pre-sharding for WNA16 act order
|
||||||
|
if (self.quant_method.__class__.__name__
|
||||||
|
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
||||||
|
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||||
|
|
||||||
|
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||||
|
|
||||||
setup_moe_comm_method(self.moe_config)
|
setup_moe_comm_method(self.moe_config)
|
||||||
|
|
||||||
def update_expert_map(self, new_expert_map):
|
def update_expert_map(self, new_expert_map):
|
||||||
@@ -210,11 +260,18 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
router_logits: torch.Tensor):
|
router_logits: torch.Tensor):
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
|
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
|
||||||
|
quantized_x_for_share, dynamic_scale_for_share = None, None
|
||||||
|
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
enable_force_load_balance = forward_context.in_profile_run
|
||||||
|
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
replace_allreduce=forward_context.sp_enabled)
|
replace_allreduce=forward_context.sp_enabled,
|
||||||
|
enable_shared_expert_dp=self.enable_shared_expert_dp)
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
final_hidden_states = self.quant_method.apply(
|
final_hidden_states = self.quant_method.apply(
|
||||||
@@ -233,11 +290,13 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
e_score_correction_bias=self.e_score_correction_bias,
|
e_score_correction_bias=self.e_score_correction_bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||||
enable_eplb=self.enable_eplb,
|
quantized_x_for_share=quantized_x_for_share,
|
||||||
expert_load_view=self.expert_load_view,
|
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||||
logical_to_physical_map=self.logical_to_physical_map,
|
shared_experts=None,
|
||||||
logical_replica_count=self.logical_replica_count,
|
enable_force_load_balance=enable_force_load_balance,
|
||||||
)
|
log2phy=self.log2phy,
|
||||||
|
global_redundant_expert_num=self.global_redundant_expert_num)
|
||||||
|
|
||||||
if isinstance(final_hidden_states, tuple):
|
if isinstance(final_hidden_states, tuple):
|
||||||
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
|
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
|
||||||
|
|
||||||
@@ -361,8 +420,3 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
|||||||
if self.multistream_overlap_shared_expert:
|
if self.multistream_overlap_shared_expert:
|
||||||
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
|
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
|
||||||
return shared_out, fused_output
|
return shared_out, fused_output
|
||||||
|
|
||||||
|
|
||||||
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
|
|
||||||
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
|
|
||||||
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|
|
||||||
|
|||||||
@@ -1,455 +0,0 @@
|
|||||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
||||||
# Copyright 2023 The vLLM team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# This file is a part of the vllm-ascend project.
|
|
||||||
# Adapted from vllm/tests/kernels/test_moe.py
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Any, Callable, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch_npu
|
|
||||||
from vllm.config import get_current_vllm_config
|
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
||||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
|
||||||
get_tp_group)
|
|
||||||
from vllm.forward_context import get_forward_context
|
|
||||||
from vllm.model_executor.layers.fused_moe.config import \
|
|
||||||
FusedMoEConfig # isort: skip
|
|
||||||
from vllm.model_executor.layers.fused_moe.config import \
|
|
||||||
FusedMoEParallelConfig # isort: skip
|
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (
|
|
||||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import \
|
|
||||||
QuantizationConfig
|
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
|
||||||
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
|
||||||
determine_default_log2phy_map)
|
|
||||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
|
||||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
|
||||||
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
|
|
||||||
get_all_reduce_merge_state,
|
|
||||||
get_rm_router_logits_state, is_310p)
|
|
||||||
|
|
||||||
|
|
||||||
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|
||||||
|
|
||||||
def __init__(self, moe: FusedMoEConfig = None):
|
|
||||||
|
|
||||||
super().__init__(moe=moe)
|
|
||||||
vllm_config = get_current_vllm_config()
|
|
||||||
|
|
||||||
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
|
||||||
self.max_model_len = vllm_config.model_config.max_model_len
|
|
||||||
get_ascend_config()
|
|
||||||
self.dynamic_eplb = get_ascend_config().dynamic_eplb
|
|
||||||
|
|
||||||
try:
|
|
||||||
device_group = get_mc2_group().device_group
|
|
||||||
# TODO: Try local_rank = ep_group.rank_in_group
|
|
||||||
local_rank = torch.distributed.get_rank(group=device_group)
|
|
||||||
backend = device_group._get_backend(torch.device("npu"))
|
|
||||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
|
||||||
local_rank)
|
|
||||||
except AttributeError:
|
|
||||||
self.moe_all_to_all_group_name = None
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer):
|
|
||||||
super(UnquantizedFusedMoEMethod,
|
|
||||||
self).process_weights_after_loading(layer)
|
|
||||||
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
|
||||||
layer.w13_weight.data),
|
|
||||||
requires_grad=False)
|
|
||||||
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
|
||||||
layer.w2_weight.data),
|
|
||||||
requires_grad=False)
|
|
||||||
if not is_310p():
|
|
||||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
|
||||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
|
||||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
|
||||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
top_k: int,
|
|
||||||
renormalize: bool,
|
|
||||||
use_grouped_topk: bool = False,
|
|
||||||
global_num_experts: int = -1,
|
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
scoring_func: str = "softmax",
|
|
||||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
|
||||||
is_prefill: bool = False,
|
|
||||||
enable_force_load_balance: bool = False,
|
|
||||||
shared_experts: Optional[Any] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
topk_weights, topk_ids, row_idx = select_experts(
|
|
||||||
hidden_states=x,
|
|
||||||
router_logits=router_logits,
|
|
||||||
top_k=top_k,
|
|
||||||
use_grouped_topk=use_grouped_topk,
|
|
||||||
renormalize=renormalize,
|
|
||||||
topk_group=topk_group,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
custom_routing_function=custom_routing_function,
|
|
||||||
scoring_func=scoring_func,
|
|
||||||
e_score_correction_bias=e_score_correction_bias,
|
|
||||||
global_num_experts=global_num_experts)
|
|
||||||
|
|
||||||
topk_weights = topk_weights.to(x.dtype)
|
|
||||||
# this is a naive implementation for experts load balance so as
|
|
||||||
# to avoid accumulating too much tokens on a single rank.
|
|
||||||
# currently it is only activated when doing profile runs.
|
|
||||||
if enable_force_load_balance and not self.use_aclgraph:
|
|
||||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
|
||||||
|
|
||||||
moe_comm_method = get_forward_context().moe_comm_method
|
|
||||||
return moe_comm_method.fused_experts(
|
|
||||||
hidden_states=x,
|
|
||||||
w1=layer.w13_weight,
|
|
||||||
w2=layer.w2_weight,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
row_idx=row_idx,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=expert_map,
|
|
||||||
shared_experts=shared_experts,
|
|
||||||
need_trans=True,
|
|
||||||
dynamic_eplb=self.dynamic_eplb)
|
|
||||||
|
|
||||||
|
|
||||||
class AscendFusedMoE(FusedMoE):
|
|
||||||
|
|
||||||
# The moe_counter parameter is required during the initialization of EPLB
|
|
||||||
# to identify the current layer index within the MOE model.
|
|
||||||
moe_counter = -1
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_experts: int, # Global number of experts
|
|
||||||
top_k: int,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
|
||||||
reduce_results: bool = False,
|
|
||||||
renormalize: bool = True,
|
|
||||||
use_grouped_topk: bool = False,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
tp_size: Optional[int] = None,
|
|
||||||
ep_size: Optional[int] = None,
|
|
||||||
dp_size: Optional[int] = None,
|
|
||||||
prefix: str = "",
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
scoring_func: str = "softmax",
|
|
||||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
):
|
|
||||||
# TODO: This could not initialize FusedMoE baseclass,
|
|
||||||
# fixme and make __init__() of AscendFusedMoE more clear
|
|
||||||
super().__init__(
|
|
||||||
num_experts=num_experts,
|
|
||||||
top_k=top_k,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
intermediate_size=intermediate_size,
|
|
||||||
params_dtype=params_dtype,
|
|
||||||
reduce_results=reduce_results,
|
|
||||||
renormalize=renormalize,
|
|
||||||
use_grouped_topk=use_grouped_topk,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
topk_group=topk_group,
|
|
||||||
quant_config=quant_config,
|
|
||||||
tp_size=tp_size,
|
|
||||||
ep_size=ep_size,
|
|
||||||
dp_size=dp_size,
|
|
||||||
prefix=prefix,
|
|
||||||
custom_routing_function=custom_routing_function,
|
|
||||||
scoring_func=scoring_func,
|
|
||||||
e_score_correction_bias=e_score_correction_bias,
|
|
||||||
activation=activation,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
)
|
|
||||||
AscendFusedMoE.moe_counter += 1
|
|
||||||
self.moe_instance_id = AscendFusedMoE.moe_counter
|
|
||||||
|
|
||||||
if params_dtype is None:
|
|
||||||
params_dtype = torch.get_default_dtype()
|
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
|
||||||
|
|
||||||
self.moe_parallel_config = FusedMoEParallelConfig.make(
|
|
||||||
tp_size_=(tp_size if tp_size is not None else
|
|
||||||
get_tensor_model_parallel_world_size()),
|
|
||||||
dp_size_=(dp_size
|
|
||||||
if dp_size is not None else get_dp_group().world_size),
|
|
||||||
vllm_parallel_config=vllm_config.parallel_config)
|
|
||||||
|
|
||||||
self.top_k = top_k
|
|
||||||
self.num_experts = num_experts
|
|
||||||
self.global_num_experts = num_experts
|
|
||||||
assert intermediate_size % self.tp_size == 0
|
|
||||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
|
||||||
self.reduce_results = reduce_results
|
|
||||||
self.renormalize = renormalize
|
|
||||||
self.use_grouped_topk = use_grouped_topk
|
|
||||||
if self.use_grouped_topk:
|
|
||||||
assert num_expert_group is not None and topk_group is not None
|
|
||||||
self.num_expert_group = num_expert_group
|
|
||||||
self.topk_group = topk_group
|
|
||||||
self.custom_routing_function = custom_routing_function
|
|
||||||
self.scoring_func = scoring_func
|
|
||||||
self.e_score_correction_bias = e_score_correction_bias
|
|
||||||
self.expert_map = None
|
|
||||||
self.activation = activation
|
|
||||||
self.log2phy = None
|
|
||||||
self.global_redundant_expert_num = 0
|
|
||||||
|
|
||||||
is_deepseek_v3_r1 = self.global_num_experts == 256
|
|
||||||
self.rm_router_logits = get_rm_router_logits_state(
|
|
||||||
self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
|
|
||||||
self.all_reduce_merge = get_all_reduce_merge_state(
|
|
||||||
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
|
|
||||||
|
|
||||||
ascend_config = get_ascend_config()
|
|
||||||
self.dynamic_eplb = ascend_config.dynamic_eplb
|
|
||||||
self.expert_map_path = ascend_config.expert_map_path
|
|
||||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
|
||||||
self.global_num_experts = num_experts + self.global_redundant_expert_num
|
|
||||||
# static eplb initializing with expert_map_path
|
|
||||||
if self.expert_map_path and os.path.exists(
|
|
||||||
self.expert_map_path) and os.access(self.expert_map_path,
|
|
||||||
os.R_OK):
|
|
||||||
self.expert_load_balancer = ExpertLoadBalancer(
|
|
||||||
self.expert_map_path, self.global_num_experts)
|
|
||||||
self.local_num_experts, self.expert_map = (
|
|
||||||
self.expert_load_balancer.get_rank_placement_map(
|
|
||||||
self.moe_instance_id, self.ep_rank))
|
|
||||||
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
|
|
||||||
self.moe_instance_id, self.ep_rank).npu()
|
|
||||||
self.global_redundant_expert_num = (
|
|
||||||
self.expert_load_balancer.get_global_redundant_expert_num())
|
|
||||||
else:
|
|
||||||
# init moe.
|
|
||||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
|
||||||
self.ep_size, self.ep_rank, self.global_num_experts)
|
|
||||||
# dynamic eplb initializing with not expert_map_path
|
|
||||||
if self.dynamic_eplb:
|
|
||||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
|
||||||
self.local_num_experts, self.expert_map = determine_default_expert_map(
|
|
||||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
|
||||||
self.global_redundant_expert_num)
|
|
||||||
self.log2phy = determine_default_log2phy_map(
|
|
||||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
|
||||||
self.global_redundant_expert_num)
|
|
||||||
local_num_experts = (torch.sum(self.expert_map != -1)
|
|
||||||
if self.expert_map is not None else num_experts)
|
|
||||||
if self.dynamic_eplb:
|
|
||||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
|
||||||
|
|
||||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
|
||||||
|
|
||||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
|
||||||
raise ValueError("Only softmax scoring function is supported for "
|
|
||||||
"non-grouped topk.")
|
|
||||||
moe = FusedMoEConfig(
|
|
||||||
num_experts=self.global_num_experts,
|
|
||||||
experts_per_token=top_k,
|
|
||||||
hidden_dim=hidden_size,
|
|
||||||
num_local_experts=self.local_num_experts,
|
|
||||||
moe_parallel_config=self.moe_parallel_config,
|
|
||||||
in_dtype=params_dtype,
|
|
||||||
)
|
|
||||||
self.moe_config = moe
|
|
||||||
# TODO: The self.moe_config.tp_size here is not correct, fixme soon
|
|
||||||
|
|
||||||
if quant_config is None:
|
|
||||||
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
|
|
||||||
else:
|
|
||||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
|
||||||
|
|
||||||
assert self.quant_method is not None
|
|
||||||
|
|
||||||
local_num_experts = torch.sum(self.expert_map != -1) \
|
|
||||||
if self.expert_map is not None else num_experts
|
|
||||||
|
|
||||||
self.moe_load = None
|
|
||||||
|
|
||||||
if self.dynamic_eplb:
|
|
||||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
|
||||||
|
|
||||||
moe_quant_params = {
|
|
||||||
"num_experts": local_num_experts,
|
|
||||||
"hidden_size": hidden_size,
|
|
||||||
"intermediate_size_per_partition":
|
|
||||||
self.intermediate_size_per_partition,
|
|
||||||
"params_dtype": params_dtype,
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
}
|
|
||||||
# need full intermediate size pre-sharding for WNA16 act order
|
|
||||||
if (self.quant_method.__class__.__name__
|
|
||||||
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
|
||||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
|
||||||
|
|
||||||
self.ep_group = get_ep_group()
|
|
||||||
# NOTE: self.tp_group is not expert_tp_group
|
|
||||||
self.tp_group = get_tp_group().device_group
|
|
||||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
|
||||||
|
|
||||||
self.moe_config.tp_group = get_tp_group()
|
|
||||||
self.moe_config.dp_group = get_dp_group()
|
|
||||||
self.moe_config.ep_group = get_ep_group()
|
|
||||||
self.moe_config.mc2_group = get_mc2_group()
|
|
||||||
self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num
|
|
||||||
|
|
||||||
setup_moe_comm_method(self.moe_config)
|
|
||||||
|
|
||||||
def update_expert_map(self, new_expert_map):
|
|
||||||
self.expert_map = new_expert_map
|
|
||||||
|
|
||||||
def get_map(self):
|
|
||||||
return self.expert_map
|
|
||||||
|
|
||||||
def get_log2phy_map(self):
|
|
||||||
return self.logical_to_physical_map
|
|
||||||
|
|
||||||
def clear_moe_load(self):
|
|
||||||
if self.moe_load is not None:
|
|
||||||
self.moe_load.zero_()
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
is_prefill: bool,
|
|
||||||
enable_force_load_balance: bool = False,
|
|
||||||
top_k: Optional[int] = None,
|
|
||||||
shared_experts: Optional[Any] = None,
|
|
||||||
gate=None,
|
|
||||||
replace_allreduce: bool = False):
|
|
||||||
|
|
||||||
assert self.quant_method is not None
|
|
||||||
|
|
||||||
if top_k:
|
|
||||||
real_top_k = top_k
|
|
||||||
else:
|
|
||||||
real_top_k = self.top_k
|
|
||||||
|
|
||||||
forward_context = get_forward_context()
|
|
||||||
mc2_mask = forward_context.mc2_mask
|
|
||||||
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
|
|
||||||
quantized_x_for_share, dynamic_scale_for_share = None, None
|
|
||||||
|
|
||||||
if shared_experts:
|
|
||||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
|
||||||
shared_hidden_states = shared_experts(hidden_states)
|
|
||||||
|
|
||||||
if forward_context.sp_enabled:
|
|
||||||
replace_allreduce = True
|
|
||||||
|
|
||||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
router_logits=router_logits,
|
|
||||||
enable_shared_expert_dp=self.enable_shared_expert_dp,
|
|
||||||
rm_router_logits=self.rm_router_logits,
|
|
||||||
replace_allreduce=replace_allreduce,
|
|
||||||
gate=gate)
|
|
||||||
|
|
||||||
# Matrix multiply.
|
|
||||||
e_hidden_states = self.quant_method.apply(
|
|
||||||
layer=self,
|
|
||||||
x=hidden_states,
|
|
||||||
router_logits=router_logits,
|
|
||||||
top_k=real_top_k,
|
|
||||||
renormalize=self.renormalize,
|
|
||||||
use_grouped_topk=self.use_grouped_topk,
|
|
||||||
global_num_experts=self.global_num_experts,
|
|
||||||
expert_map=self.expert_map,
|
|
||||||
topk_group=self.topk_group,
|
|
||||||
num_expert_group=self.num_expert_group,
|
|
||||||
custom_routing_function=self.custom_routing_function,
|
|
||||||
scoring_func=self.scoring_func,
|
|
||||||
e_score_correction_bias=self.e_score_correction_bias,
|
|
||||||
is_prefill=is_prefill,
|
|
||||||
enable_force_load_balance=enable_force_load_balance,
|
|
||||||
log2phy=self.log2phy,
|
|
||||||
global_redundant_expert_num=self.global_redundant_expert_num,
|
|
||||||
shared_experts=None,
|
|
||||||
mc2_mask=mc2_mask,
|
|
||||||
quantized_x_for_share=quantized_x_for_share,
|
|
||||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
|
||||||
)
|
|
||||||
|
|
||||||
group_list_type = None
|
|
||||||
|
|
||||||
if shared_experts:
|
|
||||||
if isinstance(e_hidden_states,
|
|
||||||
tuple) and len(e_hidden_states) == 2:
|
|
||||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
|
||||||
|
|
||||||
if isinstance(e_hidden_states, tuple) and len(e_hidden_states) == 3:
|
|
||||||
e_hidden_states, group_list_type, expert_tokens = e_hidden_states
|
|
||||||
|
|
||||||
if self.dynamic_eplb and group_list_type is not None:
|
|
||||||
self.moe_load += expert_tokens if group_list_type else \
|
|
||||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
|
||||||
|
|
||||||
final_hidden_states = forward_context.moe_comm_method.finalize(
|
|
||||||
hidden_states=e_hidden_states,
|
|
||||||
reduce_results=(not self.all_reduce_merge))
|
|
||||||
|
|
||||||
if shared_experts:
|
|
||||||
return final_hidden_states, shared_hidden_states
|
|
||||||
else:
|
|
||||||
return final_hidden_states
|
|
||||||
|
|
||||||
# ----------------------------------------- TBO-related --------------------------------------------
|
|
||||||
|
|
||||||
def _forward_ms_fused_moe_comp(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
is_prefill: bool,
|
|
||||||
real_top_k,
|
|
||||||
enable_force_load_balance: bool = False,
|
|
||||||
):
|
|
||||||
hidden_states = self.quant_method.apply(
|
|
||||||
layer=self,
|
|
||||||
x=hidden_states,
|
|
||||||
router_logits=router_logits,
|
|
||||||
top_k=real_top_k,
|
|
||||||
renormalize=self.renormalize,
|
|
||||||
use_grouped_topk=self.use_grouped_topk,
|
|
||||||
global_num_experts=self.global_num_experts,
|
|
||||||
expert_map=self.expert_map,
|
|
||||||
topk_group=self.topk_group,
|
|
||||||
num_expert_group=self.num_expert_group,
|
|
||||||
custom_routing_function=self.custom_routing_function,
|
|
||||||
scoring_func=self.scoring_func,
|
|
||||||
e_score_correction_bias=self.e_score_correction_bias,
|
|
||||||
is_prefill=is_prefill,
|
|
||||||
enable_force_load_balance=enable_force_load_balance,
|
|
||||||
)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
@@ -26,6 +26,8 @@ from vllm.distributed.parallel_state import (
|
|||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||||
|
|
||||||
|
from vllm_ascend.utils import get_rm_router_logits_state
|
||||||
|
|
||||||
|
|
||||||
class FusedMoEPrepareAndFinalize(ABC):
|
class FusedMoEPrepareAndFinalize(ABC):
|
||||||
"""
|
"""
|
||||||
@@ -41,13 +43,16 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
|
|
||||||
def __init__(self, moe_config: FusedMoEConfig):
|
def __init__(self, moe_config: FusedMoEConfig):
|
||||||
self.moe_config = moe_config
|
self.moe_config = moe_config
|
||||||
|
is_deepseek_v3_r1 = self.moe_config.original_num_experts == 256
|
||||||
|
self.rm_router_logits = get_rm_router_logits_state(
|
||||||
|
self.moe_config.ep_size, self.moe_config.dp_size,
|
||||||
|
is_deepseek_v3_r1)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def prepare(self,
|
def prepare(self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
rm_router_logits: bool = False,
|
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
@@ -61,7 +66,6 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
|
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
|
||||||
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
|
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
|
||||||
enable_shared_expert_dp (bool): Skip DP communication for shared experts
|
enable_shared_expert_dp (bool): Skip DP communication for shared experts
|
||||||
rm_router_logits (bool): Discard input router_logits and recompute via gate
|
|
||||||
replace_allreduce (bool): Bypass default all-reduce behavior
|
replace_allreduce (bool): Bypass default all-reduce behavior
|
||||||
gate (nn.Module, optional): Gate network to recompute router_logits if needed
|
gate (nn.Module, optional): Gate network to recompute router_logits if needed
|
||||||
|
|
||||||
@@ -116,7 +120,6 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
rm_router_logits: bool = False,
|
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
@@ -215,7 +218,6 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
rm_router_logits: bool = False,
|
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
@@ -294,7 +296,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
rm_router_logits: bool = False,
|
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
@@ -302,7 +303,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
|||||||
1. Fetch max token count across DP group from forward context.
|
1. Fetch max token count across DP group from forward context.
|
||||||
2. Pad local tensors to that size.
|
2. Pad local tensors to that size.
|
||||||
3. All-gather across DP group to form global input tensor.
|
3. All-gather across DP group to form global input tensor.
|
||||||
4. Optionally recompute router_logits using gate if `rm_router_logits=True`.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (global_hidden_states, global_router_logits, None)
|
Tuple of (global_hidden_states, global_router_logits, None)
|
||||||
@@ -318,14 +318,14 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
|||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
hidden_states = nn.functional.pad(hidden_states,
|
hidden_states = nn.functional.pad(hidden_states,
|
||||||
(0, 0, 0, pad_size))
|
(0, 0, 0, pad_size))
|
||||||
if not rm_router_logits:
|
if not self.rm_router_logits:
|
||||||
router_logits = nn.functional.pad(router_logits,
|
router_logits = nn.functional.pad(router_logits,
|
||||||
(0, 0, 0, pad_size))
|
(0, 0, 0, pad_size))
|
||||||
|
|
||||||
# All-gather across DP group
|
# All-gather across DP group
|
||||||
hidden_states = self.moe_config.dp_group.all_gather(
|
hidden_states = self.moe_config.dp_group.all_gather(
|
||||||
hidden_states, 0)
|
hidden_states, 0)
|
||||||
if rm_router_logits:
|
if self.rm_router_logits:
|
||||||
router_logits, _ = gate(hidden_states) # Recompute globally
|
router_logits, _ = gate(hidden_states) # Recompute globally
|
||||||
else:
|
else:
|
||||||
router_logits = self.moe_config.dp_group.all_gather(
|
router_logits = self.moe_config.dp_group.all_gather(
|
||||||
@@ -399,14 +399,12 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
rm_router_logits: bool = False,
|
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Preparation steps:
|
Preparation steps:
|
||||||
1. Fetch cumulative token boundaries from forward context.
|
1. Fetch cumulative token boundaries from forward context.
|
||||||
2. Multicast hidden_states and router_logits to form global tensors.
|
2. Multicast hidden_states and router_logits to form global tensors.
|
||||||
3. Optionally recompute router_logits globally if `rm_router_logits=True`.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (global_hidden_states, global_router_logits, None)
|
Tuple of (global_hidden_states, global_router_logits, None)
|
||||||
@@ -418,7 +416,7 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
|||||||
).dp_metadata.cu_tokens_across_sp(1)
|
).dp_metadata.cu_tokens_across_sp(1)
|
||||||
hidden_states = self._naive_multicast(hidden_states,
|
hidden_states = self._naive_multicast(hidden_states,
|
||||||
self.cu_tokens_across_dp_cpu)
|
self.cu_tokens_across_dp_cpu)
|
||||||
if rm_router_logits:
|
if self.rm_router_logits:
|
||||||
router_logits, _ = gate(hidden_states)
|
router_logits, _ = gate(hidden_states)
|
||||||
else:
|
else:
|
||||||
router_logits = self._naive_multicast(
|
router_logits = self._naive_multicast(
|
||||||
|
|||||||
@@ -67,12 +67,11 @@ class MoECommMethod(ABC):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
rm_router_logits: bool = False,
|
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
gate=None) -> tuple[torch.Tensor, torch.Tensor]:
|
gate=None) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
|
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
|
||||||
hidden_states, router_logits, enable_shared_expert_dp,
|
hidden_states, router_logits, enable_shared_expert_dp,
|
||||||
rm_router_logits, replace_allreduce, gate)
|
replace_allreduce, gate)
|
||||||
self.mc2_mask = mc2_mask
|
self.mc2_mask = mc2_mask
|
||||||
return hidden_states, router_logits
|
return hidden_states, router_logits
|
||||||
|
|
||||||
|
|||||||
@@ -468,9 +468,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.with_quant = False
|
self.with_quant = False
|
||||||
self.num_local_experts = kwargs.get("num_local_experts", 0)
|
self.num_local_experts = kwargs.get("num_local_experts", 0)
|
||||||
self.num_global_redundant_experts = kwargs.get(
|
|
||||||
"num_global_redundant_experts", 0)
|
|
||||||
self.num_experts = self.num_experts + self.num_global_redundant_experts
|
|
||||||
|
|
||||||
self.hidden_shape = None
|
self.hidden_shape = None
|
||||||
self.topk_weights = None
|
self.topk_weights = None
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from vllm.model_executor.utils import set_weight_attrs
|
|||||||
|
|
||||||
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
||||||
get_otp_group)
|
get_otp_group)
|
||||||
from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
|
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
|
||||||
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
|
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
|
||||||
oproj_tp_enable)
|
oproj_tp_enable)
|
||||||
|
|
||||||
|
|||||||
@@ -53,9 +53,9 @@ from vllm.sequence import IntermediateTensors
|
|||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
|
||||||
from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
|
from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
|
||||||
init_metadata_for_sp)
|
init_metadata_for_sp)
|
||||||
|
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
|
||||||
|
|
||||||
|
|
||||||
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
||||||
@@ -81,7 +81,7 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
|||||||
prefix=f"{prefix}.gate",
|
prefix=f"{prefix}.gate",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.experts = AscendFusedMoE(
|
self.experts = TorchairAscendFusedMoE(
|
||||||
num_experts=config.num_experts,
|
num_experts=config.num_experts,
|
||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
|
|||||||
@@ -880,7 +880,7 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
# this is a naive implementation for experts load balance so as
|
# this is a naive implementation for experts load balance so as
|
||||||
# to avoid accumulating too much tokens on a single rank.
|
# to avoid accumulating too much tokens on a single rank.
|
||||||
# currently it is only activated when doing profile runs.
|
# currently it is only activated when doing profile runs.
|
||||||
if enable_force_load_balance and not self.use_aclgraph:
|
if enable_force_load_balance:
|
||||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||||
|
|
||||||
fused_moe_state = get_forward_context().fused_moe_state
|
fused_moe_state = get_forward_context().fused_moe_state
|
||||||
|
|||||||
Reference in New Issue
Block a user