[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)

### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business
`**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring
behavior indirectly.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-03-20 23:23:57 +08:00
committed by GitHub
parent c860535246
commit 88d03a783f
33 changed files with 2146 additions and 947 deletions

View File

@@ -245,6 +245,7 @@ def test_qwen3_dense_prefetch_mlp_weight_tp2(model):
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"}) @patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"}) @patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
@wait_until_npu_memory_free()
def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep(): def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
short_example_prompts = [ short_example_prompts = [
"Hello ", "Hello ",
@@ -272,6 +273,7 @@ def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"}) @patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"}) @patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
@wait_until_npu_memory_free()
def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep(): def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep():
short_example_prompts = [ short_example_prompts = [
"Hello ", "Hello ",

View File

@@ -28,11 +28,17 @@ import torch
import torch_npu import torch_npu
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm_ascend.ops.fused_moe.experts_selector import ( from vllm_ascend.ops.fused_moe.experts_selector import check_npu_moe_gating_top_k, select_experts
check_npu_moe_gating_top_k, select_experts)
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.token_dispatcher import \ from vllm_ascend.ops.fused_moe.moe_runtime_args import (
TokenDispatcherWithAllGather build_fused_experts_input,
build_mlp_compute_input,
MoEQuantParams,
MoERoutingParams,
MoETokenDispatchInput,
)
from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather
from vllm_ascend.quantization.quant_type import QuantType
NUM_EXPERTS = [8, 64] NUM_EXPERTS = [8, 64]
EP_SIZE = [1] EP_SIZE = [1]
@@ -83,10 +89,8 @@ def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
for i in range(w1.shape[0]): for i in range(w1.shape[0]):
mask = topk_ids == i mask = topk_ids == i
if mask.sum(): if mask.sum():
out[mask] = SiluAndMul()( out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) return (out.view(B, -1, w2.shape[1]) * topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
return (out.view(B, -1, w2.shape[1]) *
topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@@ -129,36 +133,41 @@ def test_token_dispatcher_with_all_gather(
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs) dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
apply_router_weight_on_input = False apply_router_weight_on_input = False
dispatch_output = dispatcher.token_dispatch( token_dispatch_output = dispatcher.token_dispatch(
hidden_states=a, token_dispatch_input=MoETokenDispatchInput(
topk_weights=topk_weights, hidden_states=a,
topk_ids=topk_ids, topk_weights=topk_weights,
expert_map=expert_map, topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input) routing=MoERoutingParams(
expert_map=expert_map,
global_redundant_expert_num=0,
mc2_mask=None,
apply_router_weight_on_input=apply_router_weight_on_input,
),
quant=MoEQuantParams(quant_type=QuantType.NONE),
)
)
sorted_hidden_states = dispatch_output.hidden_states sorted_hidden_states = token_dispatch_output.hidden_states
group_list = dispatch_output.group_list group_list = token_dispatch_output.group_list
group_list_type = dispatch_output.group_list_type group_list_type = token_dispatch_output.group_list_type
context_metadata = dispatch_output.context_metadata combine_metadata = token_dispatch_output.combine_metadata
expert_output = apply_mlp(hidden_states=sorted_hidden_states, expert_output = apply_mlp(
w1=w1_local, hidden_states=sorted_hidden_states,
w2=w2_local, w1=w1_local,
group_list=group_list, w2=w2_local,
group_list_type=group_list_type) group_list=group_list,
group_list_type=group_list_type,
)
combined_output = dispatcher.token_combine( combined_output = dispatcher.token_combine(
hidden_states=expert_output, hidden_states=expert_output, combine_metadata=combine_metadata, bias=None
context_metadata=context_metadata, )
bias=None)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map)
expert_map)
torch.testing.assert_close(combined_output.routed_out, torch.testing.assert_close(combined_output, torch_output, atol=4e-2, rtol=1)
torch_output,
atol=4e-2,
rtol=1)
gc.collect() gc.collect()
torch.npu.empty_cache() torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats() torch.npu.reset_peak_memory_stats()
@@ -184,8 +193,7 @@ def test_token_dispatcher_with_all_gather_quant(
): ):
context_mock = MagicMock() context_mock = MagicMock()
context_mock.fused_moe_state = 0 context_mock.fused_moe_state = 0
with patch("vllm_ascend.ascend_forward_context.get_forward_context", with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context", return_value=context_mock):
return_value=context_mock):
a = torch.randn((m, k), device=device, dtype=dtype) / 10 a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8) w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype) w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype)
@@ -208,34 +216,44 @@ def test_token_dispatcher_with_all_gather_quant(
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs) dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
apply_router_weight_on_input = False apply_router_weight_on_input = False
dispatch_output = dispatcher.token_dispatch( token_dispatch_output = dispatcher.token_dispatch(
hidden_states=a, token_dispatch_input=MoETokenDispatchInput(
topk_weights=topk_weights, hidden_states=a,
topk_ids=topk_ids, topk_weights=topk_weights,
expert_map=expert_map, topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input, routing=MoERoutingParams(
with_quant=True) expert_map=expert_map,
global_redundant_expert_num=0,
mc2_mask=None,
apply_router_weight_on_input=apply_router_weight_on_input,
),
quant=MoEQuantParams(quant_type=QuantType.W8A8),
)
)
sorted_hidden_states = dispatch_output.hidden_states combine_metadata = token_dispatch_output.combine_metadata
group_list = dispatch_output.group_list
group_list_type = dispatch_output.group_list_type
dynamic_scale = dispatch_output.dynamic_scale
context_metadata = dispatch_output.context_metadata
expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states, mlp_compute_input = build_mlp_compute_input(
w1=w1, fused_experts_input=build_fused_experts_input(
w1_scale=w1_scale, hidden_states=a,
w2=w2, topk_weights=topk_weights,
w2_scale=w2_scale, topk_ids=topk_ids,
group_list=group_list, w1=w1,
group_list_type=group_list_type, w2=w2,
dynamic_scale=dynamic_scale, quant_type=QuantType.W8A8,
with_quant=True) dynamic_eplb=False,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
),
token_dispatch_output=token_dispatch_output,
use_fusion_ops=False,
)
expert_output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
combined_output = dispatcher.token_combine( combined_output = dispatcher.token_combine(
hidden_states=expert_output, hidden_states=expert_output, combine_metadata=combine_metadata, bias=None
context_metadata=context_metadata, )
bias=None) assert combined_output.shape == (m, k)
assert combined_output.routed_out.shape == (m, k)
gc.collect() gc.collect()
torch.npu.empty_cache() torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats() torch.npu.reset_peak_memory_stats()
@@ -271,25 +289,20 @@ def test_select_experts(
hidden_states = torch.randn(m, n, device=device, dtype=dtype) hidden_states = torch.randn(m, n, device=device, dtype=dtype)
router_logits = torch.randn(m, e, device=device, dtype=dtype) router_logits = torch.randn(m, e, device=device, dtype=dtype)
e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype) e_score_correction_bias = torch.randn(e, device=device, dtype=dtype) if with_e_correction else None
if with_e_correction else None)
custom_routing_function = None custom_routing_function = None
if custom_routing: if custom_routing:
custom_routing_function = MagicMock() custom_routing_function = MagicMock()
mock_weights = torch.randn(m, topk, device=device, dtype=dtype) mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
mock_ids = torch.randint(0, mock_ids = torch.randint(0, e, (m, topk), device=device, dtype=torch.int32)
e, (m, topk),
device=device,
dtype=torch.int32)
custom_routing_function.return_value = (mock_weights, mock_ids) custom_routing_function.return_value = (mock_weights, mock_ids)
with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk" with (
) as mock_native_grouped_topk, \ patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk") as mock_native_grouped_topk,
patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method', patch("vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method", return_value=MagicMock()),
return_value=MagicMock()): ):
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like( mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(x)
x)
topk_weights, topk_ids = select_experts( topk_weights, topk_ids = select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
@@ -305,8 +318,8 @@ def test_select_experts(
) )
call_moe_gatingtopk = check_npu_moe_gating_top_k( call_moe_gatingtopk = check_npu_moe_gating_top_k(
hidden_states, topk, renormalize, topk_group, num_expert_group, hidden_states, topk, renormalize, topk_group, num_expert_group, scoring_func, custom_routing_function
scoring_func, custom_routing_function) )
if not call_moe_gatingtopk and use_grouped_topk: if not call_moe_gatingtopk and use_grouped_topk:
mock_native_grouped_topk.assert_called_once() mock_native_grouped_topk.assert_called_once()
else: else:
@@ -323,16 +336,18 @@ def test_select_experts(
@pytest.mark.parametrize("device", DEVICE) @pytest.mark.parametrize("device", DEVICE)
def test_select_experts_invalid_scoring_func(device: str): def test_select_experts_invalid_scoring_func(device: str):
with patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method', with (
return_value=MagicMock()), \ patch("vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method", return_value=MagicMock()),
pytest.raises(ValueError, pytest.raises(ValueError, match="Unsupported scoring function: invalid"),
match="Unsupported scoring function: invalid"): ):
select_experts(hidden_states=torch.randn(1, 128, device=device), select_experts(
router_logits=torch.randn(1, 8, device=device), hidden_states=torch.randn(1, 128, device=device),
top_k=2, router_logits=torch.randn(1, 8, device=device),
use_grouped_topk=False, top_k=2,
renormalize=False, use_grouped_topk=False,
scoring_func="invalid") renormalize=False,
scoring_func="invalid",
)
gc.collect() gc.collect()
torch.npu.empty_cache() torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats() torch.npu.reset_peak_memory_stats()

View File

@@ -19,6 +19,38 @@ import torch
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend._310p.fused_moe.moe_mlp import unified_apply_mlp from vllm_ascend._310p.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEMlpComputeInput,
MoEQuantParams,
MoEWeights,
)
from vllm_ascend.quantization.quant_type import QuantType
def build_mlp_compute_input_fixture(
*,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
with_quant: bool,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
group_list_type: int = 1,
) -> MoEMlpComputeInput:
return MoEMlpComputeInput(
hidden_states=hidden_states,
group_list=group_list,
group_list_type=group_list_type,
dynamic_scale=None,
topk_scales=None,
weights=MoEWeights(w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale),
quant=MoEQuantParams(quant_type=QuantType.W8A8 if with_quant else QuantType.NONE),
fusion=False,
activation="silu",
need_trans=False,
dynamic_eplb=False,
)
class TestUnifiedApplyMLP310(TestBase): class TestUnifiedApplyMLP310(TestBase):
@@ -38,14 +70,13 @@ class TestUnifiedApplyMLP310(TestBase):
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
result = unified_apply_mlp( result = unified_apply_mlp(
hidden_states=hidden_states, mlp_compute_input=build_mlp_compute_input_fixture(
w1=w1, hidden_states=hidden_states,
w1_scale=None, w1=w1,
w2=w2, w2=w2,
w2_scale=None, group_list=group_list,
group_list=group_list, with_quant=False,
group_list_type=1, )
with_quant=False,
) )
self.assertEqual(mock_npu_grouped_matmul.call_count, 2) self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
@@ -94,14 +125,15 @@ class TestUnifiedApplyMLP310(TestBase):
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
result = unified_apply_mlp( result = unified_apply_mlp(
hidden_states=hidden_states, mlp_compute_input=build_mlp_compute_input_fixture(
w1=w1, hidden_states=hidden_states,
w1_scale=w1_scale, w1=w1,
w2=w2, w2=w2,
w2_scale=w2_scale, group_list=group_list,
group_list=group_list, with_quant=True,
group_list_type=1, w1_scale=w1_scale,
with_quant=True, w2_scale=w2_scale,
)
) )
mock_cumsum.assert_called_once() mock_cumsum.assert_called_once()

View File

@@ -95,4 +95,4 @@ def test_SiluAndMul_forward_310p(
assert torch.allclose(actual_arg, expected_arg), "swiglu called with unexpected input" assert torch.allclose(actual_arg, expected_arg), "swiglu called with unexpected input"
expected_out = (dummy_tensor[..., :h] + 1) * dummy_tensor[..., h:] expected_out = (dummy_tensor[..., :h] + 1) * dummy_tensor[..., h:]
assert torch.allclose(out, expected_out) assert torch.allclose(out, expected_out)

View File

@@ -12,7 +12,7 @@
# limitations under the License. # limitations under the License.
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
# #
from typing import List, TypedDict from typing import TypedDict
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@@ -20,12 +20,19 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch_npu import torch_npu
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list, from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp
unified_apply_mlp) from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEMlpComputeInput,
MoEPrepareOutput,
MoEQuantParams,
MoEWeights,
)
from vllm_ascend.quantization.quant_type import QuantType
from vllm_ascend.utils import AscendDeviceType, adapt_patch from vllm_ascend.utils import AscendDeviceType, adapt_patch
adapt_patch(True) adapt_patch(True)
@@ -54,6 +61,51 @@ def mock_npu_format_cast(weight_data, format):
return weight_data return weight_data
def build_mlp_compute_input_fixture(
*,
hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
group_list: torch.Tensor,
with_quant: bool,
group_list_type: int = 1,
dynamic_scale: torch.Tensor | None = None,
topk_scales: torch.Tensor | None = None,
w1_scale: torch.Tensor | list[torch.Tensor] | None = None,
w2_scale: torch.Tensor | list[torch.Tensor] | None = None,
w1_scale_bias: torch.Tensor | None = None,
w2_scale_bias: torch.Tensor | None = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
fusion: bool = False,
activation: str = "silu",
need_trans: bool = True,
dynamic_eplb: bool = False,
) -> MoEMlpComputeInput:
return MoEMlpComputeInput(
hidden_states=hidden_states,
group_list=group_list,
group_list_type=group_list_type,
dynamic_scale=dynamic_scale,
topk_scales=topk_scales,
weights=MoEWeights(
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
w1_offset=w1_offset,
w2_offset=w2_offset,
),
quant=MoEQuantParams(quant_type=QuantType.W8A8 if with_quant else QuantType.NONE),
fusion=fusion,
activation=activation,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb,
)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_vllm_config_mock(mocker: MockerFixture): def setup_vllm_config_mock(mocker: MockerFixture):
mock_hf_config = MagicMock() mock_hf_config = MagicMock()
@@ -77,7 +129,13 @@ def mock_dist_env(mocker: MockerFixture):
mock_moe_comm_method = MagicMock() mock_moe_comm_method = MagicMock()
def mock_prepare(hidden_states, router_logits, **kwargs): def mock_prepare(hidden_states, router_logits, **kwargs):
return hidden_states, router_logits return MoEPrepareOutput(
hidden_states=hidden_states,
router_logits=router_logits,
mc2_mask=kwargs.get("mc2_mask"),
padded_hidden_states_shape=None,
pertoken_scale=None,
)
mock_moe_comm_method.prepare.side_effect = mock_prepare mock_moe_comm_method.prepare.side_effect = mock_prepare
@@ -204,18 +262,18 @@ def moe_method(mock_dist_env):
class Device(TypedDict): class Device(TypedDict):
device_id: int device_id: int
device_expert: List[int] device_expert: list[int]
class Layer(TypedDict): class Layer(TypedDict):
layer_id: int layer_id: int
device_count: int device_count: int
device_list: List[Device] device_list: list[Device]
class MockData(TypedDict): class MockData(TypedDict):
moe_layer_count: int moe_layer_count: int
layer_list: List[Layer] layer_list: list[Layer]
class MockQuantMethod(nn.Module): class MockQuantMethod(nn.Module):
@@ -338,18 +396,15 @@ class TestUnifiedApplyMLP(TestBase):
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16) w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
result = unified_apply_mlp(hidden_states=hidden_states, result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
w1=w1, hidden_states=hidden_states,
w1_scale=w1_scale, w1=w1,
w2=w2, w2=w2,
w2_scale=w2_scale, group_list=group_list,
group_list=group_list, with_quant=True,
dynamic_scale=None, w1_scale=w1_scale,
group_list_type=1, w2_scale=w2_scale,
w1_scale_bias=None, ))
w2_scale_bias=None,
topk_scales=None,
with_quant=True)
mock_get_forward_context.assert_called() mock_get_forward_context.assert_called()
@@ -383,18 +438,14 @@ class TestUnifiedApplyMLP(TestBase):
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
topk_scales = torch.randn(10, 1, dtype=torch.float16) topk_scales = torch.randn(10, 1, dtype=torch.float16)
result = unified_apply_mlp(hidden_states=hidden_states, result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
w1=w1, hidden_states=hidden_states,
w1_scale=None, w1=w1,
w2=w2, w2=w2,
w2_scale=None, group_list=group_list,
group_list=group_list, with_quant=False,
dynamic_scale=None, topk_scales=topk_scales,
group_list_type=1, ))
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales,
with_quant=False)
self.assertEqual(mock_npu_grouped_matmul.call_count, 2) self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once() mock_npu_swiglu.assert_called_once()
@@ -445,18 +496,18 @@ class TestUnifiedApplyMLP(TestBase):
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32) provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
result = unified_apply_mlp(hidden_states=hidden_states, result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
w1=w1, hidden_states=hidden_states,
w1_scale=w1_scale, w1=w1,
w2=w2, w2=w2,
w2_scale=w2_scale, group_list=group_list,
group_list=group_list, with_quant=True,
dynamic_scale=provided_dynamic_scale, dynamic_scale=provided_dynamic_scale,
group_list_type=1, w1_scale=w1_scale,
w1_scale_bias=w1_scale_bias, w2_scale=w2_scale,
w2_scale_bias=w2_scale_bias, w1_scale_bias=w1_scale_bias,
topk_scales=None, w2_scale_bias=w2_scale_bias,
with_quant=True) ))
mock_get_forward_context.assert_called() mock_get_forward_context.assert_called()
@@ -490,18 +541,14 @@ class TestUnifiedApplyMLP(TestBase):
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
topk_scales = torch.randn(10, 1, dtype=torch.float16) topk_scales = torch.randn(10, 1, dtype=torch.float16)
result = unified_apply_mlp(hidden_states=hidden_states, result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
w1=w1, hidden_states=hidden_states,
w1_scale=None, w1=w1,
w2=w2, w2=w2,
w2_scale=None, group_list=group_list,
group_list=group_list, with_quant=False,
dynamic_scale=None, topk_scales=topk_scales,
group_list_type=1, ))
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales,
with_quant=False)
self.assertEqual(mock_npu_grouped_matmul.call_count, 2) self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once() mock_npu_swiglu.assert_called_once()
@@ -556,19 +603,19 @@ class TestUnifiedApplyMLP(TestBase):
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32) provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
result = unified_apply_mlp(hidden_states=hidden_states, result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
w1=w1, hidden_states=hidden_states,
w1_scale=w1_scale, w1=w1,
w2=w2, w2=w2,
w2_scale=w2_scale, group_list=group_list,
group_list=group_list, with_quant=True,
dynamic_scale=provided_dynamic_scale, dynamic_scale=provided_dynamic_scale,
group_list_type=1, w1_scale=w1_scale,
w1_scale_bias=w1_scale_bias, w2_scale=w2_scale,
w2_scale_bias=w2_scale_bias, w1_scale_bias=w1_scale_bias,
topk_scales=None, w2_scale_bias=w2_scale_bias,
with_quant=True, fusion=True,
fusion=True) ))
mock_get_forward_context.assert_called() mock_get_forward_context.assert_called()
mock_npu_grouped_matmul.assert_called_once() mock_npu_grouped_matmul.assert_called_once()

View File

@@ -4,12 +4,21 @@ import torch
from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl, from vllm_ascend.ops.fused_moe.moe_comm_method import (
AlltoAllCommImpl, AllGatherCommImpl,
MC2CommImpl) AlltoAllCommImpl,
MC2CommImpl,
)
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEAllGatherCombineMetadata,
MoEFusedExpertsInput,
MoEPrepareOutput,
MoEQuantParams,
MoERoutingParams,
MoEWeights,
)
from vllm_ascend.ops.fused_moe.token_dispatcher import MoETokenDispatchOutput
from vllm_ascend.quantization.methods.base import QuantType from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
TokenDispatchResult)
class TestMoECommMethod(TestBase): class TestMoECommMethod(TestBase):
@@ -45,8 +54,11 @@ class TestMoECommMethod(TestBase):
# Mock prepare finalize # Mock prepare finalize
mock_pf_instance = MagicMock() mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8), mock_pf_instance.prepare.return_value = MoEPrepareOutput(
torch.randn(4, 2), None, None) hidden_states=torch.randn(4, 8),
router_logits=torch.randn(4, 2),
mc2_mask=None,
padded_hidden_states_shape=None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance mock_prepare_finalize.return_value = mock_pf_instance
@@ -60,8 +72,9 @@ class TestMoECommMethod(TestBase):
# Test prepare method # Test prepare method
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare( prepare_output = comm_impl.prepare(hidden_states, router_logits)
hidden_states, router_logits) h_out = prepare_output.hidden_states
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
# Verify prepare was called with correct arguments # Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with( mock_pf_instance.prepare.assert_called_once_with(
@@ -70,7 +83,7 @@ class TestMoECommMethod(TestBase):
# Test finalize method # Test finalize method
comm_impl.finalize(h_out, comm_impl.finalize(h_out,
reduce_results=True, reduce_results=True,
context_metadata=context_metadata) padded_hidden_states_shape=padded_hidden_states_shape)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.ascend_forward_context.get_forward_context')
@@ -86,10 +99,11 @@ class TestMoECommMethod(TestBase):
# Mock prepare finalize # Mock prepare finalize
mock_pf_instance = MagicMock() mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8), mock_pf_instance.prepare.return_value = MoEPrepareOutput(
torch.randn(4, 2), hidden_states=torch.randn(4, 8),
torch.tensor([1, 0, 1, router_logits=torch.randn(4, 2),
0]), None) mc2_mask=torch.tensor([1, 0, 1, 0]),
padded_hidden_states_shape=None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance mock_prepare_finalize.return_value = mock_pf_instance
@@ -103,8 +117,9 @@ class TestMoECommMethod(TestBase):
# Test prepare method # Test prepare method
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare( prepare_output = comm_impl.prepare(hidden_states, router_logits)
hidden_states, router_logits) h_out = prepare_output.hidden_states
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
# Verify prepare was called with correct arguments # Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with( mock_pf_instance.prepare.assert_called_once_with(
@@ -113,7 +128,7 @@ class TestMoECommMethod(TestBase):
# Test finalize method # Test finalize method
comm_impl.finalize(h_out, comm_impl.finalize(h_out,
reduce_results=True, reduce_results=True,
context_metadata=context_metadata) padded_hidden_states_shape=padded_hidden_states_shape)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.ascend_forward_context.get_forward_context')
@@ -133,8 +148,11 @@ class TestMoECommMethod(TestBase):
# Mock prepare finalize # Mock prepare finalize
mock_pf_instance = MagicMock() mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8), mock_pf_instance.prepare.return_value = MoEPrepareOutput(
torch.randn(4, 2), None, None) hidden_states=torch.randn(4, 8),
router_logits=torch.randn(4, 2),
mc2_mask=None,
padded_hidden_states_shape=None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance mock_prepare_finalize.return_value = mock_pf_instance
@@ -148,8 +166,7 @@ class TestMoECommMethod(TestBase):
# Test prepare method # Test prepare method
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare( _ = comm_impl.prepare(hidden_states, router_logits)
hidden_states, router_logits)
# Verify prepare was called with correct arguments # Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with( mock_pf_instance.prepare.assert_called_once_with(
@@ -174,19 +191,27 @@ class TestMoECommMethod(TestBase):
# Mock prepare finalize # Mock prepare finalize
mock_pf_instance = MagicMock() mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8), mock_pf_instance.prepare.return_value = MoEPrepareOutput(
torch.randn(4, 2), None) hidden_states=torch.randn(4, 8),
router_logits=torch.randn(4, 2),
mc2_mask=None,
padded_hidden_states_shape=None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance mock_prepare_finalize.return_value = mock_pf_instance
# Mock token dispatcher # Mock token dispatcher
mock_td_instance = MagicMock() mock_td_instance = MagicMock()
mock_td_instance.token_dispatch.return_value = TokenDispatchResult( dispatch_topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2], [0.6, 0.4]])
hidden_states=torch.randn(6, 8), mock_td_instance.token_dispatch.return_value = MoETokenDispatchOutput(
group_list=torch.tensor([2, 2, 2]), hidden_states=torch.randn(6, 8),
group_list_type=1) group_list=torch.tensor([2, 2, 2]),
mock_td_instance.token_combine.return_value = TokenCombineResult( group_list_type=1,
routed_out=torch.randn(4, 8)) combine_metadata=MoEAllGatherCombineMetadata(
topk_weights=dispatch_topk_weights,
expanded_row_idx=torch.arange(8, dtype=torch.int32),
restore_shape=torch.Size([4, 8]),
))
mock_td_instance.token_combine.return_value = torch.randn(4, 8)
mock_token_dispatcher.return_value = mock_td_instance mock_token_dispatcher.return_value = mock_td_instance
# Mock unified_apply_mlp # Mock unified_apply_mlp
@@ -199,8 +224,7 @@ class TestMoECommMethod(TestBase):
hidden_states = torch.randn(4, 8).contiguous() hidden_states = torch.randn(4, 8).contiguous()
w1 = torch.randn(16, 8).contiguous() w1 = torch.randn(16, 8).contiguous()
w2 = torch.randn(16, 8).contiguous() w2 = torch.randn(16, 8).contiguous()
topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2], topk_weights = dispatch_topk_weights
[0.6, 0.4]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
# Make sure tensors are contiguous and have correct strides # Make sure tensors are contiguous and have correct strides
@@ -208,12 +232,25 @@ class TestMoECommMethod(TestBase):
w1 = w1.contiguous() w1 = w1.contiguous()
w2 = w2.contiguous() w2 = w2.contiguous()
result = comm_impl.fused_experts(hidden_states=hidden_states, result = comm_impl.fused_experts(fused_experts_input=MoEFusedExpertsInput(
w1=[w1], hidden_states=hidden_states,
w2=[w2], topk_weights=topk_weights,
topk_weights=topk_weights, topk_ids=topk_ids,
topk_ids=topk_ids, weights=MoEWeights(
activation="silu") w1=[w1],
w2=[w2],
),
routing=MoERoutingParams(
expert_map=None,
global_redundant_expert_num=0,
mc2_mask=None,
apply_router_weight_on_input=False,
),
activation="silu",
need_trans=False,
dynamic_eplb=False,
quant=MoEQuantParams(),
))
# Verify result shape # Verify result shape
self.assertEqual(result.routed_out.shape, (4, 8)) self.assertEqual(result.routed_out.shape, (4, 8))
@@ -223,6 +260,12 @@ class TestMoECommMethod(TestBase):
# Verify unified_apply_mlp was called # Verify unified_apply_mlp was called
mock_unified_apply_mlp.assert_called_once() mock_unified_apply_mlp.assert_called_once()
mlp_compute_input = mock_unified_apply_mlp.call_args.kwargs["mlp_compute_input"]
self.assertFalse(mlp_compute_input.fusion)
self.assertFalse(mlp_compute_input.quant.is_mxfp)
# Verify token_combine was called # Verify token_combine was called
mock_td_instance.token_combine.assert_called_once() mock_td_instance.token_combine.assert_called_once_with(
hidden_states=mock_unified_apply_mlp.return_value,
combine_metadata=mock_td_instance.token_dispatch.return_value.combine_metadata,
)

View File

@@ -1,9 +1,17 @@
import unittest import unittest
from typing import ClassVar from typing import ClassVar
from unittest.mock import patch
import torch import torch
from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEMlpComputeInput,
MoEQuantParams,
MoEWeights,
)
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
from vllm_ascend.quantization.quant_type import QuantType
class TestCumsumGroupList(unittest.TestCase): class TestCumsumGroupList(unittest.TestCase):
@@ -14,7 +22,7 @@ class TestCumsumGroupList(unittest.TestCase):
cls.glist_dict = { cls.glist_dict = {
0: torch.tensor([0, 2, 3, 3]), 0: torch.tensor([0, 2, 3, 3]),
1: torch.tensor([0, 2, 1, 0]), 1: torch.tensor([0, 2, 1, 0]),
2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]]) 2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]]),
} }
support_combine = [(0, 0), (1, 0), (0, 1)] support_combine = [(0, 0), (1, 0), (0, 1)]
@@ -23,29 +31,101 @@ class TestCumsumGroupList(unittest.TestCase):
def test_cumsum_group_list_supported_conversion(self): def test_cumsum_group_list_supported_conversion(self):
for src_list_type, dst_list_type in self.support_combine: for src_list_type, dst_list_type in self.support_combine:
with self.subTest(src=src_list_type, dst=dst_list_type): with self.subTest(src=src_list_type, dst=dst_list_type):
result = cumsum_group_list(self.glist_dict[src_list_type], result = cumsum_group_list(self.glist_dict[src_list_type], src_list_type, dst_list_type, expert_num=4)
src_list_type, self.assertTrue(torch.equal(result, self.glist_dict[dst_list_type]))
dst_list_type,
expert_num=4)
self.assertTrue(
torch.equal(result, self.glist_dict[dst_list_type]))
def test_cumsum_group_list_invalid_type_valueerror(self): def test_cumsum_group_list_invalid_type_valueerror(self):
with self.assertRaises(ValueError) as excinfo: with self.assertRaises(ValueError) as excinfo:
cumsum_group_list(self.glist_dict[0], 4, 0) cumsum_group_list(self.glist_dict[0], 4, 0)
self.assertIn("group_list_type should be in [0, 1, 2], but received", self.assertIn("group_list_type should be in [0, 1, 2], but received", str(excinfo.exception))
str(excinfo.exception))
def test_cumsum_group_list_unsupported_conversion_notimplementederror( def test_cumsum_group_list_unsupported_conversion_notimplementederror(self):
self):
for src_list_type, dst_list_type in self.unsupported_combine: for src_list_type, dst_list_type in self.unsupported_combine:
with self.subTest(src=src_list_type, dst=dst_list_type): with self.subTest(src=src_list_type, dst=dst_list_type):
with self.assertRaises(NotImplementedError) as excinfo: with self.assertRaises(NotImplementedError) as excinfo:
cumsum_group_list(self.glist_dict[0], src_list_type, cumsum_group_list(self.glist_dict[0], src_list_type, dst_list_type)
dst_list_type) self.assertIn("This feature is under development.", str(excinfo.exception))
self.assertIn("This feature is under development.",
str(excinfo.exception))
if __name__ == '__main__': class TestUnifiedApplyMlpRequest(unittest.TestCase):
def test_request_unquant_path(self):
hidden_states = torch.randn(2, 8)
expected = torch.randn(2, 8)
mlp_compute_input = MoEMlpComputeInput(
hidden_states=hidden_states,
group_list=torch.tensor([2, 2], dtype=torch.int64),
group_list_type=1,
dynamic_scale=None,
topk_scales=None,
weights=MoEWeights(
w1=torch.randn(1, 16, 8),
w2=torch.randn(1, 8, 8),
w1_bias=torch.randn(1, 16),
w2_bias=torch.randn(1, 8),
),
quant=MoEQuantParams(quant_type=QuantType.NONE),
fusion=False,
activation="silu",
need_trans=False,
dynamic_eplb=False,
)
with (
patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp", return_value=expected) as mock_unquant,
patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp") as mock_quant,
):
output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
self.assertTrue(output is expected)
mock_unquant.assert_called_once()
self.assertEqual(mock_unquant.call_args.kwargs["activation"], "silu")
self.assertFalse(mock_unquant.call_args.kwargs["need_trans"])
mock_quant.assert_not_called()
def test_request_quant_path(self):
hidden_states = torch.randn(2, 8)
expected = torch.randn(2, 8)
mlp_compute_input = MoEMlpComputeInput(
hidden_states=hidden_states,
group_list=torch.tensor([2, 2], dtype=torch.int64),
group_list_type=1,
dynamic_scale=torch.randn(2, 1),
topk_scales=None,
weights=MoEWeights(
w1=torch.randn(1, 16, 8),
w2=torch.randn(1, 8, 8),
w1_scale=[torch.randn(1)],
w2_scale=[torch.randn(1)],
),
quant=MoEQuantParams(
quant_type=QuantType.MXFP8,
mxfp=MoEMxfpParams(
act_quant_type=torch.float8_e4m3fn,
weight_quant_type=torch.float8_e4m3fn,
use_bf16=False,
),
),
fusion=True,
activation="silu",
need_trans=False,
dynamic_eplb=True,
)
with (
patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp", return_value=expected) as mock_quant,
patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp") as mock_unquant,
):
output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
self.assertTrue(output is expected)
mock_quant.assert_called_once()
quant_kwargs = mock_quant.call_args.kwargs
self.assertTrue(quant_kwargs["use_mxfp_quant"])
self.assertTrue(quant_kwargs["fusion"])
self.assertTrue(quant_kwargs["dynamic_eplb"])
self.assertFalse(quant_kwargs["use_bf16"])
mock_unquant.assert_not_called()
if __name__ == "__main__":
unittest.main(verbosity=2) unittest.main(verbosity=2)

View File

@@ -0,0 +1,240 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import torch
import vllm_ascend.ops.fused_moe.moe_runtime_args as runtime_args
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEAllGatherCombineMetadata,
MoETokenDispatchOutput,
MoEWeights,
build_fused_experts_input,
build_mlp_compute_input,
build_token_dispatch_input,
)
from vllm_ascend.quantization.quant_type import QuantType
class TestMoERuntimeArgs(unittest.TestCase):
def test_runtime_args_facade_exports_public_contracts_and_builders(self):
expected_symbols = [
"MoEAllGatherCombineMetadata",
"MoEAllToAllCombineMetadata",
"MoEFusedExpertsInput",
"MoEMC2CombineMetadata",
"MoEMlpComputeInput",
"MoEPrepareOutput",
"MoEQuantParams",
"MoERoutingParams",
"MoETokenDispatchInput",
"MoETokenDispatchOutput",
"MoEWeights",
"TMoECombineMetadata",
"build_fused_experts_input",
"build_mlp_compute_input",
"build_token_dispatch_input",
]
for symbol in expected_symbols:
with self.subTest(symbol=symbol):
self.assertTrue(hasattr(runtime_args, symbol))
self.assertFalse(hasattr(runtime_args, "MoEMxfpParams"))
def test_build_fused_experts_input_preserves_runtime_semantics(self):
for quant_type in (
QuantType.NONE,
QuantType.W4A16,
QuantType.W4A8,
QuantType.W8A8,
QuantType.MXFP8,
):
with self.subTest(quant_type=quant_type):
hidden_states = torch.randn(4, 8)
topk_weights = torch.randn(4, 2)
topk_ids = torch.randint(0, 4, (4, 2), dtype=torch.int32)
fused_experts_input = build_fused_experts_input(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1=torch.randn(2, 8, 16),
w2=torch.randn(2, 16, 8),
quant_type=quant_type,
dynamic_eplb=True,
expert_map=torch.tensor([0, 1, 2, 3], dtype=torch.int32),
global_redundant_expert_num=2,
mc2_mask=torch.tensor([True, False, True, False]),
apply_router_weight_on_input=True,
log2phy=torch.tensor([3, 2, 1, 0], dtype=torch.int32),
pertoken_scale=torch.randn(4),
activation="gelu",
mxfp_act_quant_type=torch.float8_e4m3fn if quant_type == QuantType.MXFP8 else None,
)
self.assertIs(fused_experts_input.hidden_states, hidden_states)
self.assertIs(fused_experts_input.topk_weights, topk_weights)
self.assertIs(fused_experts_input.topk_ids, topk_ids)
self.assertTrue(fused_experts_input.dynamic_eplb)
self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
self.assertEqual(fused_experts_input.routing.global_redundant_expert_num, 2)
self.assertEqual(fused_experts_input.activation, "gelu")
self.assertEqual(fused_experts_input.quant.quant_type, quant_type)
def test_build_fused_experts_input_merges_dense_and_quant_weights(self):
w1 = torch.randn(2, 8, 16)
w2 = torch.randn(2, 16, 8)
w1_scale = [torch.randn(1)]
w2_scale = [torch.randn(1)]
w1_scale_bias = torch.randn(1)
w2_scale_bias = torch.randn(1)
w1_offset = torch.randn(1)
w2_offset = torch.randn(1)
fused_experts_input = build_fused_experts_input(
hidden_states=torch.randn(4, 8),
topk_weights=torch.randn(4, 2),
topk_ids=torch.randint(0, 4, (4, 2), dtype=torch.int32),
w1=w1,
w2=w2,
quant_type=QuantType.W8A8,
dynamic_eplb=False,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
w1_offset=w1_offset,
w2_offset=w2_offset,
)
self.assertIsInstance(fused_experts_input.weights, MoEWeights)
self.assertIs(fused_experts_input.weights.w1, w1)
self.assertIs(fused_experts_input.weights.w2, w2)
self.assertIs(fused_experts_input.weights.w1_scale, w1_scale)
self.assertIs(fused_experts_input.weights.w2_scale, w2_scale)
self.assertIs(fused_experts_input.weights.w1_scale_bias, w1_scale_bias)
self.assertIs(fused_experts_input.weights.w2_scale_bias, w2_scale_bias)
self.assertIs(fused_experts_input.weights.w1_offset, w1_offset)
self.assertIs(fused_experts_input.weights.w2_offset, w2_offset)
def test_build_token_dispatch_input_supports_remapped_topk_ids(self):
fused_experts_input = build_fused_experts_input(
hidden_states=torch.randn(2, 4),
topk_weights=torch.randn(2, 1),
topk_ids=torch.tensor([[0], [1]], dtype=torch.int32),
w1=torch.randn(1, 4, 8),
w2=torch.randn(1, 8, 4),
quant_type=QuantType.NONE,
dynamic_eplb=False,
)
routed_topk_ids = torch.tensor([[3], [2]], dtype=torch.int32)
token_dispatch_input = build_token_dispatch_input(
fused_experts_input=fused_experts_input,
topk_ids=routed_topk_ids,
)
self.assertIs(token_dispatch_input.hidden_states, fused_experts_input.hidden_states)
self.assertIs(token_dispatch_input.topk_weights, fused_experts_input.topk_weights)
self.assertIs(token_dispatch_input.routing, fused_experts_input.routing)
self.assertIs(token_dispatch_input.quant, fused_experts_input.quant)
self.assertIs(token_dispatch_input.topk_ids, routed_topk_ids)
def test_build_fused_experts_input_requires_primitive_mxfp_params_for_mxfp_quant(self):
with self.assertRaisesRegex(ValueError, "primitive MXFP params are required"):
build_fused_experts_input(
hidden_states=torch.randn(2, 8),
topk_weights=torch.randn(2, 2),
topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
w1=torch.randn(2, 8, 16),
w2=torch.randn(2, 16, 8),
quant_type=QuantType.MXFP8,
dynamic_eplb=False,
)
def test_build_mlp_compute_input_derives_fusion_and_preserves_mxfp_params(self):
fused_experts_input = build_fused_experts_input(
hidden_states=torch.randn(2, 8, dtype=torch.bfloat16),
topk_weights=torch.randn(2, 2),
topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
w1=torch.randn(2, 8, 16),
w2=torch.randn(2, 16, 8),
quant_type=QuantType.MXFP8,
dynamic_eplb=False,
mxfp_act_quant_type=torch.float8_e4m3fn,
mxfp_weight_quant_type=torch.float8_e4m3fn,
mxfp_scale_dtype=torch.float32,
mxfp_per_token_scale_dtype=torch.float16,
mxfp_use_bf16=False,
w1_scale=[torch.randn(1)],
w2_scale=[torch.randn(1)],
)
token_dispatch_output = MoETokenDispatchOutput(
hidden_states=torch.randn(4, 8, dtype=torch.bfloat16),
group_list=torch.tensor([2, 2], dtype=torch.int64),
group_list_type=1,
dynamic_scale=torch.randn(4, 1),
combine_metadata=MoEAllGatherCombineMetadata(
topk_weights=fused_experts_input.topk_weights,
expanded_row_idx=torch.arange(4, dtype=torch.int32),
restore_shape=torch.Size([2, 8]),
),
)
mlp_compute_input = build_mlp_compute_input(
fused_experts_input=fused_experts_input,
token_dispatch_output=token_dispatch_output,
use_fusion_ops=True,
)
self.assertIs(mlp_compute_input.hidden_states, token_dispatch_output.hidden_states)
self.assertIs(mlp_compute_input.weights, fused_experts_input.weights)
self.assertIs(mlp_compute_input.weights.w1_scale, fused_experts_input.weights.w1_scale)
self.assertIs(mlp_compute_input.weights.w2_scale, fused_experts_input.weights.w2_scale)
self.assertTrue(mlp_compute_input.fusion)
self.assertTrue(mlp_compute_input.quant.is_mxfp)
assert mlp_compute_input.quant.mxfp is not None
self.assertEqual(mlp_compute_input.quant.mxfp.scale_dtype, torch.float32)
self.assertEqual(mlp_compute_input.quant.mxfp.per_token_scale_dtype, torch.float16)
self.assertFalse(mlp_compute_input.quant.mxfp.use_bf16)
def test_build_fused_experts_input_constructs_internal_mxfp_leaf_from_primitives(self):
fused_experts_input = build_fused_experts_input(
hidden_states=torch.randn(2, 8, dtype=torch.bfloat16),
topk_weights=torch.randn(2, 2),
topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
w1=torch.randn(2, 8, 16),
w2=torch.randn(2, 16, 8),
quant_type=QuantType.MXFP8,
dynamic_eplb=False,
mxfp_act_quant_type=torch.float8_e4m3fn,
mxfp_weight_quant_type=torch.float8_e4m3fn,
mxfp_scale_dtype=torch.float32,
mxfp_per_token_scale_dtype=torch.float16,
mxfp_use_bf16=False,
)
self.assertTrue(fused_experts_input.quant.is_mxfp)
assert fused_experts_input.quant.mxfp is not None
self.assertEqual(fused_experts_input.quant.mxfp.act_quant_type, torch.float8_e4m3fn)
self.assertEqual(fused_experts_input.quant.mxfp.weight_quant_type, torch.float8_e4m3fn)
self.assertEqual(fused_experts_input.quant.mxfp.scale_dtype, torch.float32)
self.assertEqual(fused_experts_input.quant.mxfp.per_token_scale_dtype, torch.float16)
self.assertFalse(fused_experts_input.quant.mxfp.use_bf16)
if __name__ == "__main__":
unittest.main(verbosity=2)

View File

@@ -45,18 +45,22 @@ class TestPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
h_out, r_out, mask, context_metadata = layer.prepare( prepare_output = layer.prepare(hidden_states, router_logits)
hidden_states, router_logits) h_out = prepare_output.hidden_states
r_out = prepare_output.router_logits
mask = prepare_output.mc2_mask
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
# Check padding and split # Check padding and split
self.assertEqual(h_out.shape[0], 4) self.assertEqual(h_out.shape[0], 4)
self.assertEqual(r_out.shape[0], 4) self.assertEqual(r_out.shape[0], 4)
self.assertEqual(mask.tolist(), [1, 0, 1]) self.assertEqual(mask.tolist(), [1, 0, 1])
self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8]))
# Finalize # Finalize
result = layer.finalize(h_out, result = layer.finalize(h_out,
reduce_results=False, reduce_results=False,
context_metadata=context_metadata) padded_hidden_states_shape=padded_hidden_states_shape)
self.assertEqual(result.shape[0], 3) self.assertEqual(result.shape[0], 3)
@patch( @patch(
@@ -79,14 +83,19 @@ class TestPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(4, 8) hidden_states = torch.randn(4, 8)
router_logits = torch.randn(4, 2) router_logits = torch.randn(4, 2)
h_out, r_out, mask, context_metadata = layer.prepare( prepare_output = layer.prepare(
hidden_states, hidden_states,
router_logits, router_logits,
enable_shared_expert_dp=False, enable_shared_expert_dp=False,
replace_allreduce=False) replace_allreduce=False)
h_out = prepare_output.hidden_states
r_out = prepare_output.router_logits
mask = prepare_output.mc2_mask
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
# With TP=2, should split into 2 parts # With TP=2, should split into 2 parts
self.assertEqual(h_out.shape[0], 2) self.assertEqual(h_out.shape[0], 2)
self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8]))
# Mock all_gather behavior # Mock all_gather behavior
def mock_all_gather_func(tensor_list, tensor, group=None): def mock_all_gather_func(tensor_list, tensor, group=None):
@@ -101,7 +110,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
] ]
final_result = layer.finalize(h_out, final_result = layer.finalize(h_out,
reduce_results=False, reduce_results=False,
context_metadata=context_metadata) padded_hidden_states_shape=padded_hidden_states_shape)
# Should concat back to original size # Should concat back to original size
self.assertEqual(final_result.shape[0], 4) self.assertEqual(final_result.shape[0], 4)
@@ -117,15 +126,18 @@ class TestPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
h_out, r_out, _, context_metadata = layer.prepare( prepare_output = layer.prepare(hidden_states, router_logits)
hidden_states, router_logits) h_out = prepare_output.hidden_states
r_out = prepare_output.router_logits
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
# Pad to tp_size=1, so no change # Pad to tp_size=1, so no change
self.assertEqual(h_out.shape[0], 3) self.assertEqual(h_out.shape[0], 3)
self.assertEqual(padded_hidden_states_shape, torch.Size([3, 8]))
result = layer.finalize(h_out, result = layer.finalize(h_out,
reduce_results=False, reduce_results=False,
context_metadata=context_metadata) padded_hidden_states_shape=padded_hidden_states_shape)
self.assertEqual(result.shape[0], 3) self.assertEqual(result.shape[0], 3)
@patch( @patch(
@@ -141,14 +153,18 @@ class TestPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(2, 8) hidden_states = torch.randn(2, 8)
router_logits = torch.randn(2, 2) router_logits = torch.randn(2, 2)
h_out, r_out, _, context_metadata = layer.prepare( prepare_output = layer.prepare(
hidden_states, hidden_states,
router_logits, router_logits,
enable_shared_expert_dp=False, enable_shared_expert_dp=False,
replace_allreduce=False) replace_allreduce=False)
h_out = prepare_output.hidden_states
r_out = prepare_output.router_logits
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
# Split due to TP=2 # Split due to TP=2
self.assertEqual(h_out.shape[0], 1) self.assertEqual(h_out.shape[0], 1)
self.assertEqual(padded_hidden_states_shape, torch.Size([2, 8]))
# Mock all_gather # Mock all_gather
def mock_all_gather_func(tensor_list, tensor, group=None): def mock_all_gather_func(tensor_list, tensor, group=None):
@@ -163,7 +179,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
] ]
final_result = layer.finalize(h_out, final_result = layer.finalize(h_out,
reduce_results=False, reduce_results=False,
context_metadata=context_metadata) padded_hidden_states_shape=padded_hidden_states_shape)
# Should concat back # Should concat back
self.assertEqual(final_result.shape[0], 2) self.assertEqual(final_result.shape[0], 2)
@@ -200,12 +216,15 @@ class TestPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
h_out, r_out, _, context_metadata = layer.prepare( prepare_output = layer.prepare(hidden_states, router_logits)
hidden_states, router_logits) h_out = prepare_output.hidden_states
r_out = prepare_output.router_logits
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
# After all-gather with DP=2, should double the batch size # After all-gather with DP=2, should double the batch size
self.assertEqual(h_out.shape[0], 12) self.assertEqual(h_out.shape[0], 12)
self.assertEqual(r_out.shape[0], 12) self.assertEqual(r_out.shape[0], 12)
self.assertIsNone(padded_hidden_states_shape)
# Finalize with reduce_scatter # Finalize with reduce_scatter
def mock_reduce_scatter_func(tensor, dim): def mock_reduce_scatter_func(tensor, dim):
@@ -215,7 +234,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
mock_dp_group.reduce_scatter = mock_reduce_scatter_func mock_dp_group.reduce_scatter = mock_reduce_scatter_func
result = layer.finalize(h_out, result = layer.finalize(h_out,
reduce_results=False, reduce_results=False,
context_metadata=context_metadata) padded_hidden_states_shape=padded_hidden_states_shape)
self.assertEqual(result.shape[0], 3) self.assertEqual(result.shape[0], 3)

View File

@@ -17,14 +17,62 @@
from unittest.mock import MagicMock, PropertyMock, patch from unittest.mock import MagicMock, PropertyMock, patch
import numpy as np
import pytest import pytest
import torch import torch
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEAllGatherCombineMetadata,
MoEAllToAllCombineMetadata,
MoEMC2CombineMetadata,
MoEQuantParams,
MoERoutingParams,
MoETokenDispatchInput,
)
from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip
AscendDeviceType, TokenDispatcherWithAll2AllV, AscendDeviceType,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2) TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather,
TokenDispatcherWithMC2,
)
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
from vllm_ascend.quantization.quant_type import QuantType
def build_token_dispatch_input_fixture(
*,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
apply_router_weight_on_input: bool = False,
pertoken_scale: torch.Tensor | None = None,
quant_type: QuantType = QuantType.NONE,
comm_quant_mode: int | None = None,
act_quant_type: torch.dtype | None = None,
) -> MoETokenDispatchInput:
mxfp_spec = None
if quant_type == QuantType.MXFP8:
mxfp_spec = MoEMxfpParams(act_quant_type=act_quant_type)
return MoETokenDispatchInput(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
routing=MoERoutingParams(
expert_map=expert_map,
global_redundant_expert_num=global_redundant_expert_num,
mc2_mask=None,
apply_router_weight_on_input=apply_router_weight_on_input,
pertoken_scale=pertoken_scale,
),
quant=MoEQuantParams(
quant_type=quant_type,
comm_quant_mode=comm_quant_mode,
mxfp=mxfp_spec,
),
)
class TestTokenDispatcherWithMC2(TestBase): class TestTokenDispatcherWithMC2(TestBase):
@@ -85,7 +133,6 @@ class TestTokenDispatcherWithMC2(TestBase):
def test_init(self): def test_init(self):
self.assertEqual(self.dispatcher.ep_rank_id, 0) self.assertEqual(self.dispatcher.ep_rank_id, 0)
self.assertEqual(self.dispatcher.ep_world_size, 8) self.assertEqual(self.dispatcher.ep_world_size, 8)
self.assertFalse(self.dispatcher.with_quant)
self.assertTrue(self.dispatcher.enable_dispatch_v2) self.assertTrue(self.dispatcher.enable_dispatch_v2)
self.assertTrue(self.dispatcher.need_extra_args) self.assertTrue(self.dispatcher.need_extra_args)
@@ -94,10 +141,16 @@ class TestTokenDispatcherWithMC2(TestBase):
topk_ids = torch.randint(0, 8, (10, 1)) topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1) topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
mc2_mask = None token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
kwargs = self.dispatcher.get_dispatch_mc2_kwargs( topk_weights=topk_weights,
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask) topk_ids=topk_ids,
expert_map=expert_map,
global_redundant_expert_num=0,
apply_router_weight_on_input=False,
pertoken_scale=None,
)
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(token_dispatch_input)
self.assertIn("x", kwargs) self.assertIn("x", kwargs)
self.assertIn("expert_ids", kwargs) self.assertIn("expert_ids", kwargs)
self.assertEqual(kwargs["moe_expert_num"], 8) self.assertEqual(kwargs["moe_expert_num"], 8)
@@ -111,39 +164,42 @@ class TestTokenDispatcherWithMC2(TestBase):
with patch("torch_npu.npu_moe_distribute_dispatch_v2", with patch("torch_npu.npu_moe_distribute_dispatch_v2",
return_value=(torch.randn(10, 128), ) * 5 + return_value=(torch.randn(10, 128), ) * 5 +
(None, None)) as mock_dispatch: (None, None)) as mock_dispatch:
output = self.dispatcher.token_dispatch(hidden_states, token_dispatch_input = build_token_dispatch_input_fixture(
topk_weights, topk_ids, hidden_states=hidden_states,
expert_map) topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
)
output = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
mock_dispatch.assert_called_once() mock_dispatch.assert_called_once()
self.assertEqual(output.group_list_type, 0) # group_list_type == 0 self.assertEqual(output.group_list_type, 0) # group_list_type == 0
self.assertIsInstance(output.combine_metadata, MoEMC2CombineMetadata)
def test_get_combine_mc_kwargs_with_quant(self): def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True
hidden_states = torch.randn(10, 128) hidden_states = torch.randn(10, 128)
topk_ids = torch.randint(0, 8, (10, 1)) topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1) topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
mc2_mask = None
assist_info_for_combine = torch.arange(10) assist_info_for_combine = torch.arange(10)
context_metadata = { combine_metadata = MoEMC2CombineMetadata(
"topk_ids": topk_ids, topk_ids=topk_ids,
"topk_weights": topk_weights, topk_weights=topk_weights,
"expert_map": expert_map, expert_map=expert_map,
"ep_recv_counts": ep_recv_counts, ep_recv_counts=ep_recv_counts,
"mc2_mask": mc2_mask, tp_recv_counts=tp_recv_counts,
"assist_info_for_combine": assist_info_for_combine, assist_info_for_combine=assist_info_for_combine,
"expand_scales": None, expand_scales=None,
"tp_recv_counts": tp_recv_counts dispatch_with_quant=True,
} )
self.dispatcher.need_extra_args = True self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.moe_expert_num = len(expert_map) self.dispatcher.moe_expert_num = len(expert_map)
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states, kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
context_metadata) combine_metadata)
self.assertIn("tp_send_counts", kwargs) self.assertIn("tp_send_counts", kwargs)
@@ -188,14 +244,19 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights, token_dispatch_input = build_token_dispatch_input_fixture(
topk_ids, None) hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
# Verify npu_moe_init_routing is called # Verify npu_moe_init_routing is called
self.mock_npu_moe_init_routing_custom.assert_called_once() self.mock_npu_moe_init_routing_custom.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_custom.call_args args, kwargs = self.mock_npu_moe_init_routing_custom.call_args
self.assertEqual(results.group_list_type, 1) self.assertEqual(results.group_list_type, 1)
self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)
@pytest.mark.skip( @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -205,14 +266,19 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights, token_dispatch_input = build_token_dispatch_input_fixture(
topk_ids, None) hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
# Verify npu_moe_init_routing is called # Verify npu_moe_init_routing is called
self.mock_npu_moe_init_routing_custom.assert_called_once() self.mock_npu_moe_init_routing_custom.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_custom.call_args args, kwargs = self.mock_npu_moe_init_routing_custom.call_args
self.assertEqual(results.group_list_type, 1) self.assertEqual(results.group_list_type, 1)
self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)
@pytest.mark.skip( @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -230,9 +296,12 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher_quant.token_dispatch(hidden_states, token_dispatch_input = build_token_dispatch_input_fixture(
topk_weights, topk_ids, hidden_states=hidden_states,
None) topk_weights=topk_weights,
topk_ids=topk_ids,
)
results = self.dispatcher_quant.token_dispatch(token_dispatch_input=token_dispatch_input)
self.assertEqual(results.group_list_type, 1) self.assertEqual(results.group_list_type, 1)
@@ -252,11 +321,13 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher_quant.token_dispatch(hidden_states, token_dispatch_input = build_token_dispatch_input_fixture(
topk_weights, hidden_states=hidden_states,
topk_ids, topk_weights=topk_weights,
None, topk_ids=topk_ids,
with_quant=True) quant_type=QuantType.W8A8,
)
results = self.dispatcher_quant.token_dispatch(token_dispatch_input=token_dispatch_input)
self.assertIsNotNone(results.hidden_states) self.assertIsNotNone(results.hidden_states)
self.assertIsNotNone(results.group_list) self.assertIsNotNone(results.group_list)
@@ -267,40 +338,43 @@ class TestTokenDispatcherWithAllGather(TestBase):
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_combine_with_expert_map(self): def test_token_combine_with_expert_map(self):
hidden_states = torch.randn(6, 128) hidden_states = torch.randn(6, 128)
context_metadata = { combine_metadata = MoEAllGatherCombineMetadata(
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]), expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
} restore_shape=torch.Size([6, 128]),
self.dispatcher.original_shape = (6, 128) )
final_hidden_states = self.dispatcher.token_combine( final_hidden_states = self.dispatcher.token_combine(hidden_states, combine_metadata)
hidden_states, context_metadata).routed_out
self.assertEqual(final_hidden_states.shape, (6, 128)) self.assertEqual(final_hidden_states.shape, (6, 128))
@pytest.mark.skip( @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_combine_without_expert_map(self): def test_token_combine_without_expert_map(self):
hidden_states = torch.randn(6, 128) hidden_states = torch.randn(6, 128)
context_metadata = { combine_metadata = MoEAllGatherCombineMetadata(
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]), expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
} restore_shape=torch.Size([6, 128]),
self.dispatcher.original_shape = (6, 128) )
final_hidden_states = self.dispatcher.token_combine( final_hidden_states = self.dispatcher.token_combine(hidden_states, combine_metadata)
hidden_states, context_metadata).routed_out
self.mock_npu_moe_token_unpermute.assert_called_once() self.mock_npu_moe_token_unpermute.assert_called_once()
self.assertEqual(final_hidden_states.shape, (6, 128)) self.assertEqual(final_hidden_states.shape, (6, 128))
@pytest.mark.skip( @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_dispatch_with_router_weight(self): def test_token_dispatch_with_router_weight(self):
self.dispatcher.apply_router_weight_on_input = True
hidden_states = torch.randn(3, 128) hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1 topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1
topk_ids = torch.tensor([[0], [1], [2]]) topk_ids = torch.tensor([[0], [1], [2]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights, token_dispatch_input = build_token_dispatch_input_fixture(
topk_ids, None) hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=True,
)
results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
self.assertEqual(results.hidden_states.shape, (6, 128)) self.assertEqual(results.hidden_states.shape, (6, 128))
self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)
class TestTokenDispatcherWithAll2AllV(TestBase): class TestTokenDispatcherWithAll2AllV(TestBase):
@@ -408,35 +482,39 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
[0, 1], dtype=torch.int32) [0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1] self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states, token_dispatch_input = build_token_dispatch_input_fixture(
topk_weights=topk_weights, hidden_states=hidden_states,
topk_ids=topk_ids, topk_weights=topk_weights,
expert_map=expert_map) topk_ids=topk_ids,
expert_map=expert_map,
)
result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
self.assertIsNotNone(result.hidden_states) self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list) self.assertIsNotNone(result.group_list)
self.assertEqual(result.group_list_type, 1) self.assertEqual(result.group_list_type, 1)
self.assertIsInstance(result.combine_metadata, MoEAllToAllCombineMetadata)
@pytest.mark.skip( @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_combine(self): def test_token_combine(self):
hidden_states = torch.randn(16, 16) hidden_states = torch.randn(16, 16)
context_metadata = { combine_metadata = MoEAllToAllCombineMetadata(
"input_splits": [4, 4], input_splits=np.array([4, 4]),
"output_splits": [4, 4], output_splits=np.array([4, 4]),
"topk_weights": torch.rand(8, 4), topk_weights=torch.rand(8, 4),
"reversed_local_input_permutation_mapping": torch.arange(8), reversed_local_input_permutation_mapping=torch.arange(8),
"reversed_global_input_permutation_mapping": torch.arange(16), reversed_global_input_permutation_mapping=torch.arange(16),
} hidden_shape=torch.Size([8, 16]),
self.dispatcher.hidden_shape = (8, 16) hidden_shape_before_permute=torch.Size([8, 16]),
self.dispatcher.hidden_shape_before_permute = (8, 16) )
self.dispatcher.expert_ids_per_ep_rank = torch.tensor( self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32) [0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1] self.dispatcher.local_expert_indices = [0, 1]
output = self.dispatcher.token_combine(hidden_states, context_metadata) output = self.dispatcher.token_combine(hidden_states, combine_metadata)
self.assertIsNotNone(output) self.assertIsNotNone(output)
self.assertEqual(output.routed_out.shape, (8, 16)) self.assertEqual(output.shape, (8, 16))
@pytest.mark.skip( @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -454,16 +532,20 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
[0, 1], dtype=torch.int32) [0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1] self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states, token_dispatch_input = build_token_dispatch_input_fixture(
topk_weights=topk_weights, hidden_states=hidden_states,
topk_ids=topk_ids, topk_weights=topk_weights,
expert_map=expert_map, topk_ids=topk_ids,
with_quant=True) expert_map=expert_map,
quant_type=QuantType.W8A8,
)
result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
self.assertIsNotNone(result.hidden_states) self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list) self.assertIsNotNone(result.group_list)
self.assertIsNotNone(result.dynamic_scale) self.assertIsNotNone(result.dynamic_scale)
self.assertEqual(result.group_list_type, 1) self.assertEqual(result.group_list_type, 1)
self.assertIsInstance(result.combine_metadata, MoEAllToAllCombineMetadata)
@pytest.mark.skip( @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -484,14 +566,16 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
[0, 1], dtype=torch.int32) [0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1] self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states, token_dispatch_input = build_token_dispatch_input_fixture(
topk_weights=topk_weights, hidden_states=hidden_states,
topk_ids=topk_ids, topk_weights=topk_weights,
expert_map=expert_map, topk_ids=topk_ids,
with_quant=True) expert_map=expert_map,
quant_type=QuantType.W8A8,
)
result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
self.assertIsNotNone(result.hidden_states) self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list) self.assertIsNotNone(result.group_list)
self.assertIsNotNone(result.dynamic_scale) self.assertIsNotNone(result.dynamic_scale)
self.assertEqual(result.group_list_type, 1) self.assertEqual(result.group_list_type, 1)

View File

@@ -3,9 +3,8 @@ from unittest.mock import Mock, patch
import torch import torch
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.quantization.methods.w4a16 import (AscendW4A16FusedMoEMethod, from vllm_ascend.ascend_forward_context import MoECommType
pack_to_int32, from vllm_ascend.quantization.methods.w4a16 import AscendW4A16FusedMoEMethod, pack_to_int32, unpack_from_int32
unpack_from_int32)
class TestUnpackFromInt32(TestBase): class TestUnpackFromInt32(TestBase):
@@ -268,3 +267,41 @@ class TestAscendW4A16FusedMoEMethod(TestBase):
torch.equal(layer.w13_weight_packed.data, original_w13_data)) torch.equal(layer.w13_weight_packed.data, original_w13_data))
self.assertTrue( self.assertTrue(
torch.equal(layer.w2_weight_packed.data, original_w2_data)) torch.equal(layer.w2_weight_packed.data, original_w2_data))
@patch("vllm_ascend.quantization.methods.w4a16._EXTRA_CTX")
@patch("vllm_ascend.quantization.methods.w4a16.select_experts")
def test_apply_uses_explicit_dispatch_and_mlp_args(self, mock_select_experts, mock_extra_ctx):
tokens = 3
hidden_size = self.output_size
layer = self.build_layer()
x = torch.randn(tokens, hidden_size, dtype=torch.float32)
router_logits = torch.randn(tokens, self.experts, dtype=torch.float32)
topk_weights = torch.randn(tokens, 2, dtype=torch.float32)
topk_ids = torch.randint(0, self.experts, (tokens, 2), dtype=torch.int64)
mc2_mask = torch.tensor([1, 0, 1], dtype=torch.bool)
pertoken_scale = torch.randn(tokens, dtype=torch.float32)
mock_select_experts.return_value = (topk_weights, topk_ids)
mock_comm = Mock()
mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32)
mock_extra_ctx.moe_comm_method = mock_comm
mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER
self.quant_method.apply(
layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=self.experts,
activation="gelu",
apply_router_weight_on_input=True,
mc2_mask=mc2_mask,
pertoken_scale=pertoken_scale,
)
fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"]
self.assertEqual(fused_experts_input.activation, "gelu")
self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask)
self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale)

View File

@@ -3,8 +3,8 @@ from unittest.mock import Mock, patch
import torch import torch
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.quantization.methods.w8a8_dynamic import \ from vllm_ascend.ascend_forward_context import MoECommType
AscendW8A8DynamicFusedMoEMethod from vllm_ascend.quantization.methods.w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod
class TestAscendW8A8FusedMoEMethod(TestBase): class TestAscendW8A8FusedMoEMethod(TestBase):
@@ -32,8 +32,9 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
mock_ep_group = Mock() mock_ep_group = Mock()
mock_get_ep_group.return_value = mock_ep_group mock_get_ep_group.return_value = mock_ep_group
mock_ascend_config = Mock() mock_ascend_config = Mock()
mock_ascend_config.enable_chunked_prefill = False mock_ascend_config.enable_chunked_prefill = False
mock_ascend_config.multistream_overlap_gate = False
mock_ascend_config.eplb_config = Mock(dynamic_eplb=False)
mock_get_ascend_config.return_value = mock_ascend_config mock_get_ascend_config.return_value = mock_ascend_config
mock_mc2_group = Mock(device_group=0) mock_mc2_group = Mock(device_group=0)
mock_get_mc2_group.return_value = mock_mc2_group mock_get_mc2_group.return_value = mock_mc2_group
@@ -104,3 +105,125 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
new_layer = self.build_layer() new_layer = self.build_layer()
self.quant_method.process_weights_after_loading(new_layer) self.quant_method.process_weights_after_loading(new_layer)
mock_npu_format_cast.assert_called() mock_npu_format_cast.assert_called()
@patch("vllm_ascend.quantization.methods.w8a8_dynamic._EXTRA_CTX")
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.select_experts")
def test_apply_uses_explicit_dispatch_and_mlp_args(self, mock_select_experts, mock_extra_ctx):
tokens = 4
hidden_size = self.hidden_size
layer = torch.nn.Module()
layer.w13_weight = torch.randint(
-8,
8,
(self.num_experts, 2 * self.intermediate_size, hidden_size),
dtype=torch.int8,
)
layer.w2_weight = torch.randint(
-8,
8,
(self.num_experts, hidden_size, self.intermediate_size),
dtype=torch.int8,
)
layer.w13_weight_scale_fp32 = torch.ones(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32)
layer.w2_weight_scale = torch.ones(self.num_experts, hidden_size, dtype=torch.float32)
x = torch.randn(tokens, hidden_size, dtype=torch.float32)
router_logits = torch.randn(tokens, self.num_experts, dtype=torch.float32)
topk_weights = torch.randn(tokens, 2, dtype=torch.float32)
topk_ids = torch.randint(0, self.num_experts, (tokens, 2), dtype=torch.int64)
mc2_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool)
pertoken_scale = torch.randn(tokens, dtype=torch.float32)
mock_select_experts.return_value = (topk_weights, topk_ids)
mock_comm = Mock()
mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32)
mock_extra_ctx.moe_comm_method = mock_comm
mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER
self.quant_method.multistream_overlap_gate = False
self.quant_method.in_dtype = torch.float32
self.quant_method.apply(
layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=self.num_experts,
activation="gelu",
apply_router_weight_on_input=True,
mc2_mask=mc2_mask,
pertoken_scale=pertoken_scale,
)
fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"]
self.assertEqual(fused_experts_input.activation, "gelu")
self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask)
self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale)
self.assertIs(fused_experts_input.topk_weights, topk_weights)
self.assertIs(fused_experts_input.topk_ids, topk_ids)
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_flash_common3_context")
@patch("vllm_ascend.quantization.methods.w8a8_dynamic._EXTRA_CTX")
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.select_experts")
def test_apply_overlap_gate_uses_fc3_context(
self,
mock_select_experts,
mock_extra_ctx,
mock_get_flash_common3_context,
):
tokens = 4
hidden_size = self.hidden_size
layer = torch.nn.Module()
layer.w13_weight = torch.randint(
-8,
8,
(self.num_experts, 2 * self.intermediate_size, hidden_size),
dtype=torch.int8,
)
layer.w2_weight = torch.randint(
-8,
8,
(self.num_experts, hidden_size, self.intermediate_size),
dtype=torch.int8,
)
layer.w13_weight_scale_fp32 = torch.ones(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32)
layer.w2_weight_scale = torch.ones(self.num_experts, hidden_size, dtype=torch.float32)
x = torch.randn(tokens, hidden_size, dtype=torch.float32)
router_logits = torch.randn(tokens, self.num_experts, dtype=torch.float32)
topk_weights = torch.randn(tokens, 2, dtype=torch.float32)
topk_ids = torch.randint(0, self.num_experts, (tokens, 2), dtype=torch.int64)
mc2_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool)
pertoken_scale = torch.randn(tokens, dtype=torch.float32)
self.quant_method.multistream_overlap_gate = True
self.quant_method.in_dtype = torch.float32
mock_get_flash_common3_context.return_value = Mock(topk_weights=topk_weights, topk_ids=topk_ids)
mock_comm = Mock()
mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32)
mock_extra_ctx.moe_comm_method = mock_comm
mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER
self.quant_method.apply(
layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=self.num_experts,
activation="gelu",
apply_router_weight_on_input=True,
mc2_mask=mc2_mask,
pertoken_scale=pertoken_scale,
)
mock_select_experts.assert_not_called()
fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"]
self.assertEqual(fused_experts_input.activation, "gelu")
self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask)
self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale)
self.assertIs(fused_experts_input.topk_weights, topk_weights)
self.assertIs(fused_experts_input.topk_ids, topk_ids)

View File

@@ -25,7 +25,8 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods
from vllm_ascend.quantization.methods.base import QuantType from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.quant_type import QuantType
from .experts_selector import select_experts from .experts_selector import select_experts
from .moe_comm_method import AllGatherCommImpl310 from .moe_comm_method import AllGatherCommImpl310
@@ -93,13 +94,17 @@ class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod):
moe_comm_method = _EXTRA_CTX.moe_comm_method moe_comm_method = _EXTRA_CTX.moe_comm_method
final_hidden_states = moe_comm_method.fused_experts( final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x, fused_experts_input=build_fused_experts_input(
w1=layer.w13_weight, hidden_states=x,
w2=layer.w2_weight, topk_weights=topk_weights,
topk_weights=topk_weights, topk_ids=topk_ids,
topk_ids=topk_ids, w1=layer.w13_weight,
expert_map=expert_map, w2=layer.w2_weight,
apply_router_weight_on_input=apply_router_weight_on_input, quant_type=QuantType.NONE,
dynamic_eplb=False,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
),
) )
if zero_expert_num > 0 and zero_expert_type is not None: if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result final_hidden_states += zero_expert_result
@@ -218,9 +223,13 @@ class AscendFusedMoE310(FusedMoE):
assert self.quant_method is not None assert self.quant_method is not None
assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported." assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported."
hidden_states, router_logits, _, context_metadata = _EXTRA_CTX.moe_comm_method.prepare( prepare_output = _EXTRA_CTX.moe_comm_method.prepare(
hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type
) )
hidden_states = prepare_output.hidden_states
router_logits = prepare_output.router_logits
pertoken_scale = prepare_output.pertoken_scale
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
# Matrix multiply. # Matrix multiply.
fused_experts_results: FusedExpertsResult = self.quant_method.apply( fused_experts_results: FusedExpertsResult = self.quant_method.apply(
@@ -238,12 +247,13 @@ class AscendFusedMoE310(FusedMoE):
global_num_experts=self.global_num_experts, global_num_experts=self.global_num_experts,
expert_map=self.local_expert_map, expert_map=self.local_expert_map,
apply_router_weight_on_input=self.apply_router_weight_on_input, apply_router_weight_on_input=self.apply_router_weight_on_input,
pertoken_scale=pertoken_scale,
) )
routed_out = _EXTRA_CTX.moe_comm_method.finalize( routed_out = _EXTRA_CTX.moe_comm_method.finalize(
hidden_states=fused_experts_results.routed_out, hidden_states=fused_experts_results.routed_out,
reduce_results=self.reduce_results, reduce_results=self.reduce_results,
context_metadata=context_metadata, padded_hidden_states_shape=padded_hidden_states_shape,
) )
return routed_out return routed_out

View File

@@ -17,8 +17,8 @@ from __future__ import annotations
import torch import torch
from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
from .moe_mlp import unified_apply_mlp from .moe_mlp import unified_apply_mlp
from .token_dispatcher import TokenDispatcherWithAllGather310 from .token_dispatcher import TokenDispatcherWithAllGather310
@@ -35,52 +35,12 @@ class AllGatherCommImpl310(AllGatherCommImpl):
to handle the token-to-expert mapping and communication efficiently. to handle the token-to-expert mapping and communication efficiently.
""" """
def fused_experts( # type: ignore[override] def __init__(self, moe_config):
self, super().__init__(moe_config)
hidden_states: torch.Tensor, self.use_fusion_ops = False
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
use_int8_w8a8: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> FusedExpertsResult:
# This method is overridden to use the 310p-specific unified_apply_mlp
# which provides optimized MLP computation for the 310p platform
moe_comm_method = _EXTRA_CTX.moe_comm_method
assert moe_comm_method is not None, "Missing communication context"
dispatch_results = self.token_dispatcher.token_dispatch( def _apply_mlp(self, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
hidden_states=hidden_states, return unified_apply_mlp(mlp_compute_input=mlp_compute_input)
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
mlp_output = unified_apply_mlp(
hidden_states=dispatch_results.hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
group_list=dispatch_results.group_list,
group_list_type=dispatch_results.group_list_type,
with_quant=use_int8_w8a8,
)
combine_results = self.token_dispatcher.token_combine(
hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
)
return FusedExpertsResult(
routed_out=combine_results.routed_out,
group_list_type=dispatch_results.group_list_type,
expert_tokens=dispatch_results.group_list,
)
def _get_token_dispatcher(self): def _get_token_dispatcher(self):
return TokenDispatcherWithAllGather310( return TokenDispatcherWithAllGather310(

View File

@@ -18,6 +18,8 @@
import torch import torch
import torch_npu import torch_npu
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
def quant_apply_mlp( def quant_apply_mlp(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -66,17 +68,20 @@ def unquant_apply_mlp(
return hidden_states return hidden_states
def unified_apply_mlp( def unified_apply_mlp(*, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
hidden_states: torch.Tensor, hidden_states = mlp_compute_input.hidden_states
w1: torch.Tensor, w1 = mlp_compute_input.weights.w1
w2: torch.Tensor, w2 = mlp_compute_input.weights.w2
group_list: torch.Tensor, w1_scale = mlp_compute_input.weights.w1_scale
w1_scale: torch.Tensor | None = None, w2_scale = mlp_compute_input.weights.w2_scale
w2_scale: torch.Tensor | None = None, group_list = mlp_compute_input.group_list
group_list_type: int = 1, group_list_type = mlp_compute_input.group_list_type
with_quant: bool = False, assert isinstance(w1, torch.Tensor)
) -> torch.Tensor: assert isinstance(w2, torch.Tensor)
if with_quant:
if mlp_compute_input.quant.is_quant:
assert isinstance(w1_scale, torch.Tensor)
assert isinstance(w2_scale, torch.Tensor)
assert w1_scale is not None and w2_scale is not None assert w1_scale is not None and w2_scale is not None
return quant_apply_mlp( return quant_apply_mlp(
hidden_states=hidden_states, hidden_states=hidden_states,
@@ -87,7 +92,11 @@ def unified_apply_mlp(
group_list=group_list, group_list=group_list,
group_list_type=group_list_type, group_list_type=group_list_type,
) )
else:
return unquant_apply_mlp( return unquant_apply_mlp(
hidden_states=hidden_states, w1=w1, w2=w2, group_list=group_list, group_list_type=group_list_type hidden_states=hidden_states,
) w1=w1,
w2=w2,
group_list=group_list,
group_list_type=group_list_type,
)

View File

@@ -25,26 +25,27 @@
import torch import torch
from vllm.distributed.parallel_state import get_ep_group from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather, TokenDispatchResult from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEAllGatherCombineMetadata, MoETokenDispatchInput
from vllm_ascend.ops.fused_moe.token_dispatcher import MoETokenDispatchOutput, TokenDispatcherWithAllGather
class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather): class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
def token_dispatch( # type: ignore[override] def token_dispatch(
self, self,
hidden_states: torch.Tensor, token_dispatch_input: MoETokenDispatchInput,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
): ):
self.original_shape = hidden_states.shape hidden_states = token_dispatch_input.hidden_states
topk_weights = token_dispatch_input.topk_weights
topk_ids = token_dispatch_input.topk_ids
expert_map = token_dispatch_input.routing.expert_map
apply_router_weight_on_input = token_dispatch_input.routing.apply_router_weight_on_input
restore_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel() num_tokens = hidden_states.shape[:-1].numel()
self.apply_router_weight_on_input = apply_router_weight_on_input if apply_router_weight_on_input:
if self.apply_router_weight_on_input:
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)" assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape _, topk = topk_weights.shape
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True" assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
@@ -66,13 +67,16 @@ class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
) )
expert_tokens = expert_tokens.to(torch.int64) expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode group_list_type = 1 # `count` mode
context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
return TokenDispatchResult( return MoETokenDispatchOutput(
hidden_states=sorted_hidden_states, hidden_states=sorted_hidden_states,
group_list=expert_tokens, group_list=expert_tokens,
group_list_type=group_list_type, group_list_type=group_list_type,
context_metadata=context_metadata, combine_metadata=MoEAllGatherCombineMetadata(
topk_weights=topk_weights,
expanded_row_idx=expanded_row_idx,
restore_shape=restore_shape,
),
) )
def moe_init_routing(self, x, expert_idx, active_num, active_expert_range): def moe_init_routing(self, x, expert_idx, active_num, active_expert_range):

View File

@@ -25,6 +25,7 @@ from vllm.distributed import get_ep_group
from vllm_ascend._310p.fused_moe.experts_selector import select_experts from vllm_ascend._310p.fused_moe.experts_selector import select_experts
from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType
from .registry import register_scheme from .registry import register_scheme
@@ -95,7 +96,9 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
log2phy: torch.Tensor | None = None, log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
pertoken_scale: Any | None = None, pertoken_scale: Any | None = None,
**kwargs, activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None) zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -128,15 +131,19 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
moe_comm_method = _EXTRA_CTX.moe_comm_method moe_comm_method = _EXTRA_CTX.moe_comm_method
final_hidden_states = moe_comm_method.fused_experts( final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x, fused_experts_input=build_fused_experts_input(
w1=layer.w13_weight, hidden_states=x,
w1_scale=layer.w13_weight_scale, topk_weights=topk_weights,
w2=layer.w2_weight, topk_ids=topk_ids,
w2_scale=layer.w2_weight_scale, w1=layer.w13_weight,
topk_weights=topk_weights, w2=layer.w2_weight,
topk_ids=topk_ids, quant_type=self.quant_type,
expert_map=expert_map, dynamic_eplb=False,
use_int8_w8a8=True, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
),
) )
if zero_expert_num > 0 and zero_expert_type is not None: if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result final_hidden_states += zero_expert_result

View File

@@ -41,7 +41,8 @@ from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
from vllm_ascend.quantization.methods.base import QuantType from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.quant_type import QuantType
from vllm_ascend.utils import ( from vllm_ascend.utils import (
ACL_FORMAT_FRACTAL_NZ, ACL_FORMAT_FRACTAL_NZ,
enable_sp, enable_sp,
@@ -113,7 +114,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
activation: str = "silu", activation: str = "silu",
enable_force_load_balance: bool = False, enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None, log2phy: torch.Tensor = None,
**kwargs, global_redundant_expert_num: int = 0,
pertoken_scale: torch.Tensor | None = None,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None) zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -167,7 +170,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
# (due to signature constraints), we are forced to use a placeholder empty tensor. # (due to signature constraints), we are forced to use a placeholder empty tensor.
# This TODO tracks the requirement to update the C++ operator to accept Optional[Tensor] # This TODO tracks the requirement to update the C++ operator to accept Optional[Tensor]
# or None for scales in non-quantized scenarios. # or None for scales in non-quantized scenarios.
if get_forward_context().moe_comm_type == MoECommType.FUSED_MC2: if _EXTRA_CTX.moe_comm_type == MoECommType.FUSED_MC2:
w1 = [layer.w13_weight] w1 = [layer.w13_weight]
w1_scale = [torch.tensor([], dtype=torch.int64)] w1_scale = [torch.tensor([], dtype=torch.int64)]
w2 = [layer.w2_weight] w2 = [layer.w2_weight]
@@ -179,21 +182,26 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
w2_scale = None w2_scale = None
final_hidden_states = moe_comm_method.fused_experts( final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x, fused_experts_input=build_fused_experts_input(
w1=w1, hidden_states=x,
w2=w2, topk_weights=topk_weights,
w1_scale=w1_scale, topk_ids=topk_ids,
w2_scale=w2_scale, w1=w1,
w1_bias=layer.w13_bias if self.moe.has_bias else None, w2=w2,
w2_bias=layer.w2_bias if self.moe.has_bias else None, w1_bias=layer.w13_bias if self.moe.has_bias else None,
activation=activation, w2_bias=layer.w2_bias if self.moe.has_bias else None,
topk_weights=topk_weights, quant_type=QuantType.NONE,
topk_ids=topk_ids, dynamic_eplb=self.dynamic_eplb,
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, global_redundant_expert_num=global_redundant_expert_num,
dynamic_eplb=self.dynamic_eplb, mc2_mask=mc2_mask,
log2phy=log2phy, apply_router_weight_on_input=apply_router_weight_on_input,
mc2_mask=kwargs.get("mc2_mask"), log2phy=log2phy,
pertoken_scale=pertoken_scale,
activation=activation,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
) )
if zero_expert_num > 0 and zero_expert_type is not None: if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result final_hidden_states += zero_expert_result
@@ -474,23 +482,23 @@ class AscendFusedMoE(FusedMoE):
set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids) set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)
hidden_states, router_logits, mc2_mask, context_metadata = _EXTRA_CTX.moe_comm_method.prepare( prepare_output = _EXTRA_CTX.moe_comm_method.prepare(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
replace_allreduce=_EXTRA_CTX.flash_comm_v1_enabled, replace_allreduce=_EXTRA_CTX.flash_comm_v1_enabled,
enable_shared_expert_dp=self.enable_shared_expert_dp, enable_shared_expert_dp=self.enable_shared_expert_dp,
quant_type=self.quant_type, quant_type=self.quant_type,
) )
hidden_states = prepare_output.hidden_states
router_logits = prepare_output.router_logits
mc2_mask = prepare_output.mc2_mask
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
pertoken_scale = prepare_output.pertoken_scale
# Make sure the default stream waits for the gate stream to finish. # Make sure the default stream waits for the gate stream to finish.
if self.multistream_overlap_gate: if self.multistream_overlap_gate:
torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream) torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream)
if isinstance(hidden_states, tuple):
hidden_states, pertoken_scale = hidden_states
else:
pertoken_scale = None
# Matrix multiply. # Matrix multiply.
fused_experts_results: FusedExpertsResult = self.quant_method.apply( fused_experts_results: FusedExpertsResult = self.quant_method.apply(
layer=self, layer=self,
@@ -538,7 +546,7 @@ class AscendFusedMoE(FusedMoE):
routed_out = _EXTRA_CTX.moe_comm_method.finalize( routed_out = _EXTRA_CTX.moe_comm_method.finalize(
hidden_states=fused_experts_results.routed_out, hidden_states=fused_experts_results.routed_out,
reduce_results=self.reduce_results, reduce_results=self.reduce_results,
context_metadata=context_metadata, padded_hidden_states_shape=padded_hidden_states_shape,
) )
if return_with_event: if return_with_event:

View File

@@ -24,6 +24,13 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEFusedExpertsInput,
MoEMlpComputeInput,
MoEPrepareOutput,
build_mlp_compute_input,
build_token_dispatch_input,
)
from vllm_ascend.ops.fused_moe.prepare_finalize import ( from vllm_ascend.ops.fused_moe.prepare_finalize import (
PrepareAndFinalize, PrepareAndFinalize,
PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAll2All,
@@ -36,8 +43,7 @@ from vllm_ascend.ops.fused_moe.token_dispatcher import (
TokenDispatcherWithAllGather, TokenDispatcherWithAllGather,
TokenDispatcherWithMC2, TokenDispatcherWithMC2,
) )
from vllm_ascend.quantization.methods.base import QuantType from vllm_ascend.quantization.quant_type import QuantType
from vllm_ascend.quantization.quant_parser import parse_mxfp_quant_params
_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {} _MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
@@ -90,131 +96,70 @@ class MoECommMethod(ABC):
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
quant_type: QuantType = QuantType.NONE, quant_type: QuantType = QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: ) -> MoEPrepareOutput:
hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare( return self.prepare_finalize.prepare(
hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce, quant_type hidden_states,
router_logits,
enable_shared_expert_dp,
replace_allreduce,
quant_type,
) )
return hidden_states, router_logits, mc2_mask, context_metadata
def finalize( def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None self,
hidden_states: torch.Tensor,
reduce_results: bool,
padded_hidden_states_shape: torch.Size | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, context_metadata) hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, padded_hidden_states_shape)
return hidden_states return hidden_states
def fused_experts( def fused_experts(
self, self,
hidden_states: torch.Tensor, fused_experts_input: MoEFusedExpertsInput,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
w1_bias: torch.Tensor = None,
w2_bias: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
expert_map: torch.Tensor | None = None,
w1_scale: list[torch.Tensor] | None = None,
w2_scale: list[torch.Tensor] | None = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None,
pertoken_scale: torch.Tensor | None = None,
**kwargs,
): ):
# Check constraints # Check constraints
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8] assert fused_experts_input.hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
moe_comm_method = _EXTRA_CTX.moe_comm_method moe_comm_method = _EXTRA_CTX.moe_comm_method
assert moe_comm_method is not None, "Missing communication context" assert moe_comm_method is not None, "Missing communication context"
before_dispatch_evt = torch.npu.current_stream().record_event() before_dispatch_evt = torch.npu.current_stream().record_event()
# Apply log2phy if needed routed_topk_ids = fused_experts_input.topk_ids
if log2phy is not None: if fused_experts_input.routing.log2phy is not None:
topk_ids = log2phy[topk_ids] routed_topk_ids = fused_experts_input.routing.log2phy[routed_topk_ids]
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced
# by different quantization modes will be consolidated into a dataclass in a follow-up. token_dispatch_input = build_token_dispatch_input(
use_mxfp_quant = kwargs.get("use_mxfp_quant", False) fused_experts_input=fused_experts_input,
dispatch_with_quant = use_int8_w8a8 or use_int4_w4a8 or use_mxfp_quant topk_ids=routed_topk_ids,
act_quant_type, weight_quant_type, scale_type, per_token_scale_type, round_mode = parse_mxfp_quant_params( )
**kwargs token_dispatch_output = self.token_dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
mlp_compute_input = build_mlp_compute_input(
fused_experts_input=fused_experts_input,
token_dispatch_output=token_dispatch_output,
use_fusion_ops=self.use_fusion_ops,
) )
dispatch_kwargs = { mlp_output = self._apply_mlp(mlp_compute_input)
"hidden_states": hidden_states,
"topk_weights": topk_weights,
"topk_ids": topk_ids,
"expert_map": expert_map,
"global_redundant_expert_num": self.moe_config.global_redundant_expert_num,
"mc2_mask": mc2_mask,
"apply_router_weight_on_input": apply_router_weight_on_input,
"dynamic_eplb": dynamic_eplb,
"pertoken_scale": pertoken_scale,
}
if isinstance(self.token_dispatcher, TokenDispatcherWithMC2):
dispatch_kwargs["with_quant"] = dispatch_with_quant
dispatch_kwargs["comm_quant_mode"] = kwargs.get("comm_quant_mode")
dispatch_kwargs["y_dtype"] = act_quant_type if use_mxfp_quant else None
dispatch_kwargs["use_mxfp_quant"] = use_mxfp_quant
else:
dispatch_kwargs["with_quant"] = use_int8_w8a8 or use_int4_w4a8
dispatch_results = self.token_dispatcher.token_dispatch(**dispatch_kwargs)
mlp_output = unified_apply_mlp(
hidden_states=dispatch_results.hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
activation=activation,
group_list=dispatch_results.group_list,
dynamic_scale=dispatch_results.dynamic_scale,
group_list_type=dispatch_results.group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
w1_offset=w1_offset,
w2_offset=w2_offset,
topk_scales=dispatch_results.topk_scales,
with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16 or use_mxfp_quant,
fusion=(use_int8_w8a8 or use_mxfp_quant) and self.use_fusion_ops,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb,
use_mxfp_quant=use_mxfp_quant,
act_quant_type=act_quant_type,
weight_quant_type=weight_quant_type,
scale_type=scale_type,
per_token_scale_type=per_token_scale_type,
round_mode=round_mode,
use_bf16=(hidden_states.dtype == torch.bfloat16),
rollback_quant_config=kwargs.get("rollback_quant_config"),
)
before_combine_evt = torch.npu.current_stream().record_event() before_combine_evt = torch.npu.current_stream().record_event()
combine_results = self.token_dispatcher.token_combine( routed_out = self.token_dispatcher.token_combine(
hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata hidden_states=mlp_output,
combine_metadata=token_dispatch_output.combine_metadata,
) )
return FusedExpertsResult( return FusedExpertsResult(
routed_out=combine_results.routed_out, routed_out=routed_out,
before_dispatch_evt=before_dispatch_evt, before_dispatch_evt=before_dispatch_evt,
before_combine_evt=before_combine_evt, before_combine_evt=before_combine_evt,
group_list_type=dispatch_results.group_list_type, group_list_type=token_dispatch_output.group_list_type,
expert_tokens=dispatch_results.group_list, expert_tokens=token_dispatch_output.group_list,
) )
def _apply_mlp(self, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
return unified_apply_mlp(mlp_compute_input=mlp_compute_input)
@abstractmethod @abstractmethod
def _get_token_dispatcher(self) -> MoETokenDispatcher: def _get_token_dispatcher(self) -> MoETokenDispatcher:
raise NotImplementedError("_get_token_dispatcher function not implemented.") raise NotImplementedError("_get_token_dispatcher function not implemented.")
@@ -317,54 +262,32 @@ class FusedMC2CommImpl(MoECommMethod):
def fused_experts( def fused_experts(
self, self,
hidden_states: torch.Tensor, fused_experts_input: MoEFusedExpertsInput,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
w1_bias: torch.Tensor = None,
w2_bias: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
expert_map: torch.Tensor | None = None,
w1_scale: list[torch.Tensor] | None = None,
w2_scale: list[torch.Tensor] | None = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None,
pertoken_scale: torch.Tensor | None = None,
**kwargs,
): ):
assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl." assert not (fused_experts_input.weights.w1_scale is None or fused_experts_input.weights.w2_scale is None), (
"w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
)
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), ( assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), (
"token_dispatcher must be an instance of TokenDispatcherWithMC2." "token_dispatcher must be an instance of TokenDispatcherWithMC2."
) )
# Apply log2phy if needed # Apply log2phy if needed
if log2phy is not None: topk_ids = fused_experts_input.topk_ids
topk_ids = log2phy[topk_ids] if fused_experts_input.routing.log2phy is not None:
topk_ids = fused_experts_input.routing.log2phy[topk_ids]
expert_tokens = None expert_tokens = None
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
out = torch.empty_like(hidden_states) out = torch.empty_like(fused_experts_input.hidden_states)
torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore
x=hidden_states, x=fused_experts_input.hidden_states,
weight1=w1, weight1=fused_experts_input.weights.w1,
weight2=w2, weight2=fused_experts_input.weights.w2,
expert_idx=topk_ids, expert_idx=topk_ids,
scale1=w1_scale, scale1=fused_experts_input.weights.w1_scale,
scale2=w2_scale, scale2=fused_experts_input.weights.w2_scale,
probs=topk_weights.to(torch.float32), probs=fused_experts_input.topk_weights.to(torch.float32),
group=self.token_dispatcher.moe_all_to_all_group_name, group=self.token_dispatcher.moe_all_to_all_group_name,
max_output_size=65536, max_output_size=65536,
out=out, out=out,
@@ -372,16 +295,16 @@ class FusedMC2CommImpl(MoECommMethod):
) )
expert_tokens = self.expert_token_nums expert_tokens = self.expert_token_nums
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
assert expert_map is not None, "expert_map cannot be None." assert fused_experts_input.routing.expert_map is not None, "expert_map cannot be None."
out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
x=hidden_states, x=fused_experts_input.hidden_states,
expert_ids=topk_ids, expert_ids=topk_ids,
gmm1_permuted_weight=w1, gmm1_permuted_weight=fused_experts_input.weights.w1,
gmm1_permuted_weight_scale=w1_scale, gmm1_permuted_weight_scale=fused_experts_input.weights.w1_scale,
gmm2_weight=w2, gmm2_weight=fused_experts_input.weights.w2,
gmm2_weight_scale=w2_scale, gmm2_weight_scale=fused_experts_input.weights.w2_scale,
expert_smooth_scales=None, expert_smooth_scales=None,
expert_scales=topk_weights.to(torch.float32), expert_scales=fused_experts_input.topk_weights.to(torch.float32),
group_ep=self.token_dispatcher.moe_all_to_all_group_name, group_ep=self.token_dispatcher.moe_all_to_all_group_name,
ep_rank_size=self.token_dispatcher.ep_world_size, ep_rank_size=self.token_dispatcher.ep_world_size,
ep_rank_id=self.token_dispatcher.ep_rank_id, ep_rank_id=self.token_dispatcher.ep_rank_id,

View File

@@ -27,6 +27,7 @@ from vllm_ascend.device.mxfp_compat import (
ensure_mxfp8_moe_available, ensure_mxfp8_moe_available,
) )
from vllm_ascend.ops.activation import AscendSwigluOAIAndMul from vllm_ascend.ops.activation import AscendSwigluOAIAndMul
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
from vllm_ascend.utils import ( from vllm_ascend.utils import (
dispose_tensor, dispose_tensor,
enable_custom_op, enable_custom_op,
@@ -95,27 +96,17 @@ def quant_apply_mlp(
w2_offset: torch.Tensor | None = None, w2_offset: torch.Tensor | None = None,
fusion: bool = False, fusion: bool = False,
dynamic_eplb: bool = False, dynamic_eplb: bool = False,
**kwargs, use_mxfp_quant: bool = False,
act_quant_type: torch.dtype = torch.float8_e4m3fn,
weight_quant_type: torch.dtype | None = None,
scale_type: torch.dtype | None = None,
per_token_scale_type: torch.dtype | None = None,
use_bf16: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced by different
# quantization modes will be consolidated into a dataclass in a follow-up.
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
act_quant_type = torch.float8_e4m3fn
weight_quant_type = None
scale_type = None
per_token_scale_type = None
use_bf16 = True
input_hidden_dtype = hidden_states.dtype input_hidden_dtype = hidden_states.dtype
use_gmm_swiglu_quant_fusion = use_mxfp_quant or (fusion and not dynamic_eplb) use_gmm_swiglu_quant_fusion = use_mxfp_quant or (fusion and not dynamic_eplb)
if use_mxfp_quant: if use_mxfp_quant:
act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn)
weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn)
scale_type = kwargs.get("scale_type")
per_token_scale_type = kwargs.get("per_token_scale_type")
use_bf16 = kwargs.get("use_bf16", True)
ensure_mxfp8_moe_available("MXFP MoE MLP path") ensure_mxfp8_moe_available("MXFP MoE MLP path")
if w1_scale_bias is not None or w2_scale_bias is not None: if w1_scale_bias is not None or w2_scale_bias is not None:
@@ -393,34 +384,32 @@ def unquant_apply_mlp(
return hidden_states return hidden_states
def unified_apply_mlp( def unified_apply_mlp(*, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
group_list: torch.Tensor,
w1_scale: list[torch.Tensor] | None = None,
w2_scale: list[torch.Tensor] | None = None,
activation: str | None = None,
w1_bias: torch.Tensor = None,
w2_bias: torch.Tensor = None,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
topk_scales: torch.Tensor | None = None,
with_quant: bool = False,
fusion: bool = False,
need_trans: bool = True,
dynamic_eplb: bool = False,
**kwargs,
) -> torch.Tensor:
""" """
Unified MoE MLP entry. Unified MoE MLP entry.
Quant path is dispatched by DeviceOperator with explicit quant-type flags. Quant path is dispatched by DeviceOperator with explicit typed kernel flags.
""" """
if not with_quant: hidden_states = mlp_compute_input.hidden_states
group_list = mlp_compute_input.group_list
group_list_type = mlp_compute_input.group_list_type
dynamic_scale = mlp_compute_input.dynamic_scale
topk_scales = mlp_compute_input.topk_scales
w1 = mlp_compute_input.weights.w1
w2 = mlp_compute_input.weights.w2
w1_bias = mlp_compute_input.weights.w1_bias
w2_bias = mlp_compute_input.weights.w2_bias
w1_scale = mlp_compute_input.weights.w1_scale
w2_scale = mlp_compute_input.weights.w2_scale
w1_scale_bias = mlp_compute_input.weights.w1_scale_bias
w2_scale_bias = mlp_compute_input.weights.w2_scale_bias
w1_offset = mlp_compute_input.weights.w1_offset
w2_offset = mlp_compute_input.weights.w2_offset
activation = mlp_compute_input.activation
need_trans = mlp_compute_input.need_trans
dynamic_eplb = mlp_compute_input.dynamic_eplb
fusion = mlp_compute_input.fusion
if not mlp_compute_input.quant.is_quant:
return unquant_apply_mlp( return unquant_apply_mlp(
hidden_states=hidden_states, hidden_states=hidden_states,
w1=w1, w1=w1,
@@ -435,13 +424,22 @@ def unified_apply_mlp(
) )
assert w1_scale is not None and w2_scale is not None assert w1_scale is not None and w2_scale is not None
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced by different act_quant_type = torch.float8_e4m3fn
# quantization modes will be consolidated into a dataclass in a follow-up. weight_quant_type = torch.float8_e4m3fn
act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn) scale_type = None
weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn) per_token_scale_type = None
scale_type = kwargs.get("scale_type") use_bf16 = hidden_states.dtype == torch.bfloat16
per_token_scale_type = kwargs.get("per_token_scale_type") use_mxfp_quant = mlp_compute_input.quant.is_mxfp
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
if use_mxfp_quant:
mxfp = mlp_compute_input.quant.mxfp
assert mxfp is not None, "mlp_compute_input.quant.mxfp is required when quant_type is MXFP8."
act_quant_type = mxfp.act_quant_type or act_quant_type
weight_quant_type = mxfp.weight_quant_type or weight_quant_type
scale_type = mxfp.scale_dtype
per_token_scale_type = mxfp.per_token_scale_dtype
use_bf16 = mxfp.use_bf16
return quant_apply_mlp( return quant_apply_mlp(
hidden_states=hidden_states, hidden_states=hidden_states,
w1=w1, w1=w1,
@@ -457,10 +455,10 @@ def unified_apply_mlp(
w2_offset=w2_offset, w2_offset=w2_offset,
fusion=fusion, fusion=fusion,
dynamic_eplb=dynamic_eplb, dynamic_eplb=dynamic_eplb,
use_mxfp_quant=use_mxfp_quant,
act_quant_type=act_quant_type, act_quant_type=act_quant_type,
weight_quant_type=weight_quant_type, weight_quant_type=weight_quant_type,
scale_type=scale_type, scale_type=scale_type,
per_token_scale_type=per_token_scale_type, per_token_scale_type=per_token_scale_type,
use_mxfp_quant=use_mxfp_quant, use_bf16=use_bf16,
use_bf16=kwargs.get("use_bf16", True),
) )

View File

@@ -0,0 +1,244 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Typed runtime contracts and builders for fused MoE execution.
This module is the single entry point for the runtime payloads used across the
fused MoE pipeline.
Relationship overview:
stage params: reusable sub-payloads
- MoERoutingParams
- MoEQuantParams
- internal MXFP leaf: MoEMxfpParams
stage contracts: stage input/output payloads
prepare
-> MoEPrepareOutput
fused_experts input
-> MoEFusedExpertsInput
|- weights: MoEWeights
|- routing: MoERoutingParams
|- quant: MoEQuantParams
dispatch
input -> MoETokenDispatchInput
output -> MoETokenDispatchOutput[TMoECombineMetadata]
TMoECombineMetadata is one of:
- MoEAllGatherCombineMetadata
- MoEAllToAllCombineMetadata
- MoEMC2CombineMetadata
mlp
input -> MoEMlpComputeInput
combine
output -> torch.Tensor
The helper builders below adapt legacy call sites into these typed contracts.
Only the fused_moe package should need to know about the internal MXFP leaf
dataclass directly.
"""
from __future__ import annotations
import torch
import vllm_ascend.ops.fused_moe.moe_stage_params as _stage_params
from vllm_ascend.ops.fused_moe.moe_stage_contracts import (
MoEAllGatherCombineMetadata,
MoEAllToAllCombineMetadata,
MoEFusedExpertsInput,
MoEMC2CombineMetadata,
MoEMlpComputeInput,
MoEPrepareOutput,
MoETokenDispatchInput,
MoETokenDispatchOutput,
MoEWeights,
TMoECombineMetadata,
)
from vllm_ascend.ops.fused_moe.moe_stage_params import (
MoEQuantParams,
MoERoutingParams,
)
from vllm_ascend.quantization.quant_type import QuantType
def _build_mxfp_params(
*,
quant_type: QuantType,
mxfp_act_quant_type: torch.dtype | None = None,
mxfp_weight_quant_type: torch.dtype | None = None,
mxfp_scale_dtype: torch.dtype | None = None,
mxfp_per_token_scale_dtype: torch.dtype | None = None,
mxfp_use_bf16: bool | None = None,
) -> _stage_params.MoEMxfpParams | None:
if quant_type != QuantType.MXFP8:
return None
has_explicit_mxfp_args = any(
value is not None
for value in (
mxfp_act_quant_type,
mxfp_weight_quant_type,
mxfp_scale_dtype,
mxfp_per_token_scale_dtype,
mxfp_use_bf16,
)
)
if not has_explicit_mxfp_args:
raise ValueError("primitive MXFP params are required when quant_type is QuantType.MXFP8.")
return _stage_params.MoEMxfpParams(
act_quant_type=mxfp_act_quant_type,
weight_quant_type=mxfp_weight_quant_type,
scale_dtype=mxfp_scale_dtype,
per_token_scale_dtype=mxfp_per_token_scale_dtype,
use_bf16=True if mxfp_use_bf16 is None else mxfp_use_bf16,
)
def build_fused_experts_input(
*,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
quant_type: QuantType,
dynamic_eplb: bool,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
log2phy: torch.Tensor | None = None,
pertoken_scale: torch.Tensor | None = None,
activation: str = "silu",
need_trans: bool = False,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
comm_quant_mode: int | None = None,
mxfp_act_quant_type: torch.dtype | None = None,
mxfp_weight_quant_type: torch.dtype | None = None,
mxfp_scale_dtype: torch.dtype | None = None,
mxfp_per_token_scale_dtype: torch.dtype | None = None,
mxfp_use_bf16: bool | None = None,
w1_scale: list[torch.Tensor] | torch.Tensor | None = None,
w2_scale: list[torch.Tensor] | torch.Tensor | None = None,
w1_scale_bias: torch.Tensor | None = None,
w2_scale_bias: torch.Tensor | None = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
) -> MoEFusedExpertsInput:
return MoEFusedExpertsInput(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
weights=MoEWeights(
w1=w1,
w2=w2,
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
w1_offset=w1_offset,
w2_offset=w2_offset,
),
routing=MoERoutingParams(
expert_map=expert_map,
global_redundant_expert_num=global_redundant_expert_num,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
log2phy=log2phy,
pertoken_scale=pertoken_scale,
),
activation=activation,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb,
quant=MoEQuantParams(
quant_type=quant_type,
comm_quant_mode=comm_quant_mode,
mxfp=_build_mxfp_params(
quant_type=quant_type,
mxfp_act_quant_type=mxfp_act_quant_type,
mxfp_weight_quant_type=mxfp_weight_quant_type,
mxfp_scale_dtype=mxfp_scale_dtype,
mxfp_per_token_scale_dtype=mxfp_per_token_scale_dtype,
mxfp_use_bf16=mxfp_use_bf16,
),
),
)
def build_token_dispatch_input(
*,
fused_experts_input: MoEFusedExpertsInput,
topk_ids: torch.Tensor | None = None,
) -> MoETokenDispatchInput:
return MoETokenDispatchInput(
hidden_states=fused_experts_input.hidden_states,
topk_weights=fused_experts_input.topk_weights,
topk_ids=fused_experts_input.topk_ids if topk_ids is None else topk_ids,
routing=fused_experts_input.routing,
quant=fused_experts_input.quant,
)
def build_mlp_compute_input(
*,
fused_experts_input: MoEFusedExpertsInput,
token_dispatch_output: MoETokenDispatchOutput[TMoECombineMetadata],
use_fusion_ops: bool,
) -> MoEMlpComputeInput:
if fused_experts_input.quant.is_mxfp and fused_experts_input.quant.mxfp is None:
raise ValueError("fused_experts_input.quant.mxfp is required when quant_type is QuantType.MXFP8.")
return MoEMlpComputeInput(
hidden_states=token_dispatch_output.hidden_states,
group_list=token_dispatch_output.group_list,
group_list_type=token_dispatch_output.group_list_type,
dynamic_scale=token_dispatch_output.dynamic_scale,
topk_scales=token_dispatch_output.topk_scales,
weights=fused_experts_input.weights,
quant=fused_experts_input.quant,
fusion=fused_experts_input.quant.quant_type in (QuantType.W8A8, QuantType.MXFP8) and use_fusion_ops,
activation=fused_experts_input.activation,
need_trans=fused_experts_input.need_trans,
dynamic_eplb=fused_experts_input.dynamic_eplb,
)
__all__ = [
"MoEAllGatherCombineMetadata",
"MoEAllToAllCombineMetadata",
"MoEFusedExpertsInput",
"MoEMC2CombineMetadata",
"MoEMlpComputeInput",
"MoEPrepareOutput",
"MoEQuantParams",
"MoERoutingParams",
"MoETokenDispatchInput",
"MoETokenDispatchOutput",
"MoEWeights",
"TMoECombineMetadata",
"build_fused_experts_input",
"build_token_dispatch_input",
"build_mlp_compute_input",
]

View File

@@ -0,0 +1,154 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
from dataclasses import dataclass
from typing import Generic, TypeVar
import numpy as np
import torch
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEQuantParams, MoERoutingParams
TMoECombineMetadata = TypeVar("TMoECombineMetadata")
# prepare -> fused_experts
@dataclass(frozen=True, slots=True)
class MoEPrepareOutput:
"""Typed output from prepare stage."""
hidden_states: torch.Tensor
router_logits: torch.Tensor
mc2_mask: torch.Tensor | None
padded_hidden_states_shape: torch.Size | None
pertoken_scale: torch.Tensor | None = None
@dataclass(frozen=True, slots=True)
class MoEWeights:
"""Dense and quantized weight payloads consumed by MoE execution."""
w1: torch.Tensor | list[torch.Tensor]
w2: torch.Tensor | list[torch.Tensor]
w1_bias: torch.Tensor | None = None
w2_bias: torch.Tensor | None = None
w1_scale: torch.Tensor | list[torch.Tensor] | None = None
w2_scale: torch.Tensor | list[torch.Tensor] | None = None
w1_scale_bias: torch.Tensor | None = None
w2_scale_bias: torch.Tensor | None = None
w1_offset: torch.Tensor | None = None
w2_offset: torch.Tensor | None = None
@dataclass(frozen=True, slots=True)
class MoEFusedExpertsInput:
"""Top-level input for the routed experts pipeline."""
hidden_states: torch.Tensor
topk_weights: torch.Tensor
topk_ids: torch.Tensor
weights: MoEWeights
routing: MoERoutingParams
quant: MoEQuantParams
activation: str = "silu"
need_trans: bool = False
dynamic_eplb: bool = False
@dataclass(frozen=True, slots=True)
class MoETokenDispatchInput:
"""Input to token dispatch."""
hidden_states: torch.Tensor
topk_weights: torch.Tensor
topk_ids: torch.Tensor
routing: MoERoutingParams
quant: MoEQuantParams
# dispatch carry-over state consumed by combine
@dataclass(frozen=True, slots=True)
class MoEMC2CombineMetadata:
topk_ids: torch.Tensor
topk_weights: torch.Tensor
expert_map: torch.Tensor | None
ep_recv_counts: torch.Tensor
tp_recv_counts: torch.Tensor
assist_info_for_combine: torch.Tensor
expand_scales: torch.Tensor | None
dispatch_with_quant: bool
@dataclass(frozen=True, slots=True)
class MoEAllGatherCombineMetadata:
topk_weights: torch.Tensor
expanded_row_idx: torch.Tensor
restore_shape: torch.Size
@dataclass(frozen=True, slots=True)
class MoEAllToAllCombineMetadata:
input_splits: np.ndarray
output_splits: np.ndarray
topk_weights: torch.Tensor
reversed_local_input_permutation_mapping: torch.Tensor
reversed_global_input_permutation_mapping: torch.Tensor | None
hidden_shape: torch.Size
hidden_shape_before_permute: torch.Size
@dataclass(frozen=True, slots=True)
class MoETokenDispatchOutput(Generic[TMoECombineMetadata]):
hidden_states: torch.Tensor
group_list: torch.Tensor
group_list_type: int
combine_metadata: TMoECombineMetadata
dynamic_scale: torch.Tensor | None = None
topk_scales: torch.Tensor | None = None
# dispatch -> mlp -> combine
@dataclass(frozen=True, slots=True)
class MoEMlpComputeInput:
"""Input to MLP compute."""
hidden_states: torch.Tensor
group_list: torch.Tensor
group_list_type: int
dynamic_scale: torch.Tensor | None
topk_scales: torch.Tensor | None
weights: MoEWeights
quant: MoEQuantParams
fusion: bool
activation: str = "silu"
need_trans: bool = False
dynamic_eplb: bool = False
__all__ = [
"MoEPrepareOutput",
"MoEWeights",
"MoEFusedExpertsInput",
"MoETokenDispatchInput",
"MoEMC2CombineMetadata",
"MoEAllGatherCombineMetadata",
"MoEAllToAllCombineMetadata",
"MoETokenDispatchOutput",
"MoEMlpComputeInput",
"TMoECombineMetadata",
]

View File

@@ -0,0 +1,86 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
from dataclasses import dataclass
import torch
from vllm_ascend.quantization.quant_type import QuantType
@dataclass(frozen=True, slots=True)
class MoERoutingParams:
"""Routing and dispatch side inputs for one MoE invocation.
`pertoken_scale` is intentionally kept here even though it is not a pure
routing concept. It is used by pre-quantized activation flows, currently
the AllGather + EP W8A8 prepare path, where prepare emits per-token
activation scales and dispatch needs to carry them forward so the MLP
quant path can reuse those scales instead of requantizing activations.
"""
expert_map: torch.Tensor | None
global_redundant_expert_num: int
mc2_mask: torch.Tensor | None
apply_router_weight_on_input: bool
log2phy: torch.Tensor | None = None
# Precomputed activation scales from prepare stage for quantized dispatch.
pertoken_scale: torch.Tensor | None = None
@dataclass(frozen=True, slots=True)
class MoEMxfpParams:
"""Internal MXFP-only precision settings used by fused_moe runtime."""
act_quant_type: torch.dtype | None = None
weight_quant_type: torch.dtype | None = None
scale_dtype: torch.dtype | None = None
per_token_scale_dtype: torch.dtype | None = None
use_bf16: bool = True
@dataclass(frozen=True, slots=True)
class MoEQuantParams:
"""Quant mode, backend override, and optional internal MXFP leaf config."""
quant_type: QuantType = QuantType.NONE
comm_quant_mode: int | None = None
mxfp: MoEMxfpParams | None = None
@property
def is_quant(self) -> bool:
return self.quant_type != QuantType.NONE
@property
def is_mxfp(self) -> bool:
return self.quant_type == QuantType.MXFP8
@property
def is_int_quant(self) -> bool:
return self.quant_type in (QuantType.W8A8, QuantType.W4A8)
@property
def dispatch_with_quant(self) -> bool:
return self.quant_type in (QuantType.W8A8, QuantType.W4A8, QuantType.MXFP8)
__all__ = [
"MoERoutingParams",
"MoEMxfpParams",
"MoEQuantParams",
]

View File

@@ -31,7 +31,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
from vllm_ascend.quantization.methods.base import QuantType from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEPrepareOutput
from vllm_ascend.quantization.quant_type import QuantType
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
@@ -64,7 +65,7 @@ class PrepareAndFinalize(ABC):
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
quant_type: QuantType = QuantType.NONE, quant_type: QuantType = QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: ) -> MoEPrepareOutput:
""" """
Prepare tensors before MoE computation. May involve: Prepare tensors before MoE computation. May involve:
- Padding to align communication boundaries - Padding to align communication boundaries
@@ -79,16 +80,20 @@ class PrepareAndFinalize(ABC):
quant_type: none, w8a8, w4a8 or mxfp8 quant_type: none, w8a8, w4a8 or mxfp8
Returns: Returns:
Tuple of: MoEPrepareOutput:
- processed hidden_states (may be padded/sliced/broadcasted) - processed hidden_states (may be padded/sliced/broadcasted)
- processed router_logits (may be recomputed or broadcasted) - processed router_logits (may be recomputed or broadcasted)
- optional communication mask (e.g., mc2_mask for sparse ops) - optional communication mask (e.g., mc2_mask for sparse ops)
- optional context metadata (e.g., saved split_hidden_states for finalization) - optional padded hidden state shape for finalization
- optional per-token scale for quantized path
""" """
raise NotImplementedError("Prepare not implemented.") raise NotImplementedError("Prepare not implemented.")
def finalize( def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None self,
hidden_states: torch.Tensor,
reduce_results: bool,
padded_hidden_states_shape: torch.Size | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Finalize MoE output. May involve: Finalize MoE output. May involve:
@@ -130,7 +135,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
quant_type=QuantType.NONE, quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: ) -> MoEPrepareOutput:
""" """
Preparation steps: Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size. 1. Pad hidden_states and router_logits to next multiple of TP size.
@@ -140,7 +145,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns: Returns:
Tuple of (hidden_states, router_logits, None, context_metadata) — no mask used in All2All. MoEPrepareOutput where `mc2_mask` is None for All2All path.
""" """
self.replace_allreduce = replace_allreduce self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
@@ -162,12 +167,19 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
hidden_states = split_hidden_states[self.tp_rank] hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank] router_logits = split_router_logits[self.tp_rank]
context_metadata = {"padded_hidden_states_shape": padded_hidden_states_shape} return MoEPrepareOutput(
hidden_states=hidden_states,
return hidden_states, router_logits, None, context_metadata router_logits=router_logits,
mc2_mask=None,
padded_hidden_states_shape=padded_hidden_states_shape,
pertoken_scale=None,
)
def finalize( def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None self,
hidden_states: torch.Tensor,
reduce_results: bool,
padded_hidden_states_shape: torch.Size | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Finalization steps: Finalization steps:
@@ -180,12 +192,11 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
if not (self.enable_shared_expert_dp or self.replace_allreduce): if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1: if self.tp_size > 1:
assert context_metadata is not None assert padded_hidden_states_shape is not None
# Cannot reuse `split_hidden_states` from prepare phase as it # Cannot reuse `split_hidden_states` from prepare phase as it
# may share memory with original hidden_states. Since shared # may share memory with original hidden_states. Since shared
# experts may use the original tensor, reusing it would cause # experts may use the original tensor, reusing it would cause
# in-place modification during all_gather, corrupting the data. # in-place modification during all_gather, corrupting the data.
padded_hidden_states_shape = context_metadata["padded_hidden_states_shape"]
gathered_hidden_states = torch.empty( gathered_hidden_states = torch.empty(
padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype
) )
@@ -227,7 +238,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
quant_type=QuantType.NONE, quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: ) -> MoEPrepareOutput:
""" """
Preparation steps: Preparation steps:
1. Fetch `mc2_mask` and target padding length from forward context. 1. Fetch `mc2_mask` and target padding length from forward context.
@@ -238,7 +249,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True. Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns: Returns:
Tuple of (hidden_states, router_logits, mc2_mask, context_metadata), possibly sliced/padded. MoEPrepareOutput, possibly sliced/padded.
""" """
self.replace_allreduce = replace_allreduce self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
@@ -267,11 +278,13 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
hidden_states = split_hidden_states[self.tp_rank] hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank] router_logits = split_router_logits[self.tp_rank]
context_metadata = { return MoEPrepareOutput(
"padded_hidden_states_shape": padded_hidden_states_shape, hidden_states=hidden_states,
} router_logits=router_logits,
mc2_mask=mc2_mask,
return hidden_states, router_logits, mc2_mask, context_metadata padded_hidden_states_shape=padded_hidden_states_shape,
pertoken_scale=None,
)
class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
@@ -303,13 +316,13 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
quant_type=QuantType.NONE, quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: ) -> MoEPrepareOutput:
""" """
Preparation steps: Preparation steps:
AllGather hidden_states and router_logits to form global tensors. AllGather hidden_states and router_logits to form global tensors.
Returns: Returns:
Tuple of (global_hidden_states, global_router_logits, None) MoEPrepareOutput with global tensors.
""" """
if enable_sp(): if enable_sp():
return self._prepare_with_ep_group(hidden_states, router_logits, quant_type) return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
@@ -318,7 +331,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
def _prepare_with_ep_group( def _prepare_with_ep_group(
self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: ) -> MoEPrepareOutput:
pertoken_scale = None pertoken_scale = None
if quant_type == QuantType.W8A8: if quant_type == QuantType.W8A8:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
@@ -342,10 +355,13 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
if self.multistream_overlap_gate: if self.multistream_overlap_gate:
torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream) torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream)
if pertoken_scale is not None: return MoEPrepareOutput(
return (hidden_states, pertoken_scale), router_logits, None, None hidden_states=hidden_states,
router_logits=router_logits,
return hidden_states, router_logits, None, None mc2_mask=None,
padded_hidden_states_shape=None,
pertoken_scale=pertoken_scale,
)
def _prepare_with_dp_group( def _prepare_with_dp_group(
self, self,
@@ -354,7 +370,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
quant_type=QuantType.NONE, quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: ) -> MoEPrepareOutput:
""" """
Preparation steps: Preparation steps:
1. Fetch max token count across DP group from forward context. 1. Fetch max token count across DP group from forward context.
@@ -362,7 +378,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
3. All-gather across DP group to form global input tensor. 3. All-gather across DP group to form global input tensor.
Returns: Returns:
Tuple of (global_hidden_states, global_router_logits, None, None) MoEPrepareOutput with global tensors.
""" """
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1: if self.moe_config.dp_size > 1:
@@ -396,10 +412,19 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
dim=0, dim=0,
) )
return hidden_states, router_logits, None, None return MoEPrepareOutput(
hidden_states=hidden_states,
router_logits=router_logits,
mc2_mask=None,
padded_hidden_states_shape=None,
pertoken_scale=None,
)
def finalize( def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None self,
hidden_states: torch.Tensor,
reduce_results: bool,
padded_hidden_states_shape: torch.Size | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Finalization steps: Finalization steps:

View File

@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from typing import Generic
import torch import torch
import torch_npu import torch_npu
@@ -31,25 +31,18 @@ from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEAllGatherCombineMetadata,
MoEAllToAllCombineMetadata,
MoEMC2CombineMetadata,
MoETokenDispatchInput,
MoETokenDispatchOutput,
TMoECombineMetadata,
)
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
@dataclass class MoETokenDispatcher(ABC, Generic[TMoECombineMetadata]):
class TokenDispatchResult:
hidden_states: torch.Tensor
group_list: torch.Tensor
group_list_type: int
dynamic_scale: torch.Tensor | None = field(default=None)
topk_scales: torch.Tensor | None = field(default=None)
context_metadata: dict = field(default_factory=dict)
@dataclass
class TokenCombineResult:
routed_out: torch.Tensor
class MoETokenDispatcher(ABC):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
""" """
Initialize the MoE Token Dispatcher. Initialize the MoE Token Dispatcher.
@@ -73,27 +66,21 @@ class MoETokenDispatcher(ABC):
@abstractmethod @abstractmethod
def token_dispatch( def token_dispatch(
self, self,
hidden_states: torch.Tensor, token_dispatch_input: MoETokenDispatchInput,
topk_weights: torch.Tensor, ) -> MoETokenDispatchOutput[TMoECombineMetadata]:
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
) -> TokenDispatchResult:
raise NotImplementedError("Dispatch function not implemented.") raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod @abstractmethod
def token_combine( def token_combine(
self, hidden_states: torch.Tensor, context_metadata: dict, bias: torch.Tensor | None = None self,
) -> TokenCombineResult: hidden_states: torch.Tensor,
combine_metadata: TMoECombineMetadata,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError("Combine function not implemented.") raise NotImplementedError("Combine function not implemented.")
class TokenDispatcherWithMC2(MoETokenDispatcher): class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
device_group = get_mc2_group().device_group device_group = get_mc2_group().device_group
@@ -110,7 +97,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
# improve communication performance. # improve communication performance.
self.need_expert_scale = is_hierarchical_communication_enabled() self.need_expert_scale = is_hierarchical_communication_enabled()
self.with_quant = False
# Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute # Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
# dispatch & combine operators with different input num_tokens per rank. # dispatch & combine operators with different input num_tokens per rank.
@@ -131,25 +117,23 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
def get_dispatch_mc2_kwargs( def get_dispatch_mc2_kwargs(
self, self,
hidden_states: torch.Tensor, token_dispatch_input: MoETokenDispatchInput,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
mc2_mask: torch.Tensor,
global_redundant_expert_num: int = 0,
**kwargs,
): ):
use_mxfp_quant = kwargs.get("use_mxfp_quant", False) hidden_states = token_dispatch_input.hidden_states
comm_quant_mode = kwargs.get("comm_quant_mode") topk_weights = token_dispatch_input.topk_weights
topk_ids = token_dispatch_input.topk_ids
expert_map = token_dispatch_input.routing.expert_map
global_redundant_expert_num = token_dispatch_input.routing.global_redundant_expert_num
comm_quant_mode = token_dispatch_input.quant.comm_quant_mode
assert expert_map is not None, "expert_map is required for MC2 token dispatch."
# NOTE: quant_mode differs by quant feature: # NOTE: quant_mode differs by quant feature:
# - Legacy int communication quantization uses quant_mode=2. # - Legacy int communication quantization uses quant_mode=2.
# - A5 MXFP8 communication uses quant_mode=4. # - A5 MXFP8 communication uses quant_mode=4.
# TODO(linfeng): The quantization-related parameters need to be consolidated into a single
# dataclass, and the FP8 MoE code path should be integrated into it going forward.
if comm_quant_mode is not None: if comm_quant_mode is not None:
quant_mode = comm_quant_mode quant_mode = comm_quant_mode
elif self.with_quant: elif token_dispatch_input.quant.dispatch_with_quant:
quant_mode = 4 if self.a5_need_extra_args and use_mxfp_quant else 2 quant_mode = 4 if self.a5_need_extra_args and token_dispatch_input.quant.is_mxfp else 2
else: else:
quant_mode = 0 quant_mode = 0
self.moe_expert_num = len(expert_map) + global_redundant_expert_num self.moe_expert_num = len(expert_map) + global_redundant_expert_num
@@ -178,10 +162,13 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"tp_rank_id": 0, "tp_rank_id": 0,
} }
) )
if self.a5_need_extra_args and use_mxfp_quant: if self.a5_need_extra_args and token_dispatch_input.quant.is_mxfp:
y_dtype = kwargs.get("y_dtype") y_dtype = torch.float8_e4m3fn
if self.with_quant: if (
y_dtype = torch.float8_e4m3fn if y_dtype is None else y_dtype token_dispatch_input.quant.mxfp is not None
and token_dispatch_input.quant.mxfp.act_quant_type is not None
):
y_dtype = token_dispatch_input.quant.mxfp.act_quant_type
stage1_kwargs.update({"tp_world_size": 1, "tp_rank_id": 0, "y_dtype": y_dtype}) stage1_kwargs.update({"tp_world_size": 1, "tp_rank_id": 0, "y_dtype": y_dtype})
if self.need_expert_scale or self.a5_need_extra_args: if self.need_expert_scale or self.a5_need_extra_args:
stage1_kwargs.update( stage1_kwargs.update(
@@ -195,22 +182,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
def token_dispatch( def token_dispatch(
self, self,
hidden_states: torch.Tensor, token_dispatch_input: MoETokenDispatchInput,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
**kwargs,
): ):
self.with_quant = with_quant kwargs_mc2 = self.get_dispatch_mc2_kwargs(token_dispatch_input)
kwargs_mc2 = self.get_dispatch_mc2_kwargs(
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num, **kwargs
)
output = ( output = (
torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2) torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
if self.enable_dispatch_v2 if self.enable_dispatch_v2
@@ -227,33 +201,32 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
expand_scales, expand_scales,
) = output[0:7] ) = output[0:7]
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"tp_recv_counts": tp_recv_counts,
"assist_info_for_combine": assist_info_for_combine,
"expand_scales": expand_scales,
}
group_list_type = 0 group_list_type = 0
return TokenDispatchResult( return MoETokenDispatchOutput(
hidden_states=expand_x, hidden_states=expand_x,
dynamic_scale=dynamic_scale, dynamic_scale=dynamic_scale,
group_list=expert_token_nums, group_list=expert_token_nums,
group_list_type=group_list_type, group_list_type=group_list_type,
context_metadata=context_metadata, combine_metadata=MoEMC2CombineMetadata(
topk_ids=token_dispatch_input.topk_ids,
topk_weights=token_dispatch_input.topk_weights,
expert_map=token_dispatch_input.routing.expert_map,
ep_recv_counts=ep_recv_counts,
tp_recv_counts=tp_recv_counts,
assist_info_for_combine=assist_info_for_combine,
expand_scales=expand_scales,
dispatch_with_quant=token_dispatch_input.quant.dispatch_with_quant,
),
) )
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, context_metadata: dict): def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, combine_metadata: MoEMC2CombineMetadata):
expert_map = context_metadata["expert_map"] expert_map = combine_metadata.expert_map
topk_ids = context_metadata["topk_ids"] topk_ids = combine_metadata.topk_ids
topk_weights = context_metadata["topk_weights"] topk_weights = combine_metadata.topk_weights
ep_recv_counts = context_metadata["ep_recv_counts"] ep_recv_counts = combine_metadata.ep_recv_counts
tp_recv_counts = context_metadata["tp_recv_counts"] tp_recv_counts = combine_metadata.tp_recv_counts
assist_info_for_combine = context_metadata["assist_info_for_combine"] assist_info_for_combine = combine_metadata.assist_info_for_combine
expand_scales = context_metadata["expand_scales"] expand_scales = combine_metadata.expand_scales
assert expert_map is not None assert expert_map is not None
@@ -267,7 +240,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"global_bs": self.global_bs, "global_bs": self.global_bs,
} }
if self.with_quant: if combine_metadata.dispatch_with_quant:
tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device) tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device)
stage3_kwargs = { stage3_kwargs = {
@@ -296,52 +269,44 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
kwargs_mc2.update(stage3_kwargs) kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2 return kwargs_mc2
def token_combine(self, hidden_states, context_metadata, bias=None): def token_combine(self, hidden_states, combine_metadata, bias=None):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, context_metadata) kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, combine_metadata)
combined_output = ( combined_output = (
torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2)
if self.enable_dispatch_v2 if self.enable_dispatch_v2
else torch_npu.npu_moe_distribute_combine(**kwargs_mc2) else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
) )
return TokenCombineResult( return combined_output
routed_out=combined_output,
)
class TokenDispatcherWithAllGather(MoETokenDispatcher): class TokenDispatcherWithAllGather(MoETokenDispatcher[MoEAllGatherCombineMetadata]):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.apply_router_weight_on_input = False
self.max_num_tokens = kwargs.get("max_num_tokens") self.max_num_tokens = kwargs.get("max_num_tokens")
num_experts_local = kwargs.get("num_local_experts", 0) num_experts_local = kwargs.get("num_local_experts", 0)
self.num_experts_local = ( self.num_experts_local = (
num_experts_local.item() if torch.is_tensor(num_experts_local) else int(num_experts_local) num_experts_local.item() if torch.is_tensor(num_experts_local) else int(num_experts_local)
) )
self.original_shape = None
self.with_quant = False
def token_dispatch( def token_dispatch(
self, self,
hidden_states: torch.Tensor, token_dispatch_input: MoETokenDispatchInput,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
): ):
self.with_quant = with_quant with_quant = token_dispatch_input.quant.is_int_quant
self.original_shape = hidden_states.shape hidden_states = token_dispatch_input.hidden_states
topk_weights = token_dispatch_input.topk_weights
topk_ids = token_dispatch_input.topk_ids
expert_map = token_dispatch_input.routing.expert_map
pertoken_scale = token_dispatch_input.routing.pertoken_scale
global_redundant_expert_num = token_dispatch_input.routing.global_redundant_expert_num
restore_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel() num_tokens = hidden_states.shape[:-1].numel()
self.apply_router_weight_on_input = apply_router_weight_on_input apply_router_weight_on_input = token_dispatch_input.routing.apply_router_weight_on_input
if self.apply_router_weight_on_input: if apply_router_weight_on_input:
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)" assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape _, topk = topk_weights.shape
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True" assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
@@ -365,35 +330,37 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
expert_tokens_num_type=1, expert_tokens_num_type=1,
expert_tokens_num_flag=True, expert_tokens_num_flag=True,
active_expert_range=[first_expert_idx, last_expert_idx], active_expert_range=[first_expert_idx, last_expert_idx],
quant_mode=1 if self.with_quant and pertoken_scale is None else -1, quant_mode=1 if with_quant and pertoken_scale is None else -1,
) )
expert_tokens = expert_tokens.to(torch.int64) expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode group_list_type = 1 # `count` mode
context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
return TokenDispatchResult( return MoETokenDispatchOutput(
hidden_states=sorted_hidden_states, hidden_states=sorted_hidden_states,
dynamic_scale=pertoken_scale if self.with_quant else None, dynamic_scale=pertoken_scale if with_quant else None,
group_list=expert_tokens, group_list=expert_tokens,
group_list_type=group_list_type, group_list_type=group_list_type,
context_metadata=context_metadata, combine_metadata=MoEAllGatherCombineMetadata(
topk_weights=topk_weights,
expanded_row_idx=expanded_row_idx,
restore_shape=restore_shape,
),
) )
def token_combine(self, hidden_states, context_metadata, bias=None): def token_combine(self, hidden_states, combine_metadata, bias=None):
assert self.original_shape is not None
final_hidden_states = torch_npu.npu_moe_token_unpermute( final_hidden_states = torch_npu.npu_moe_token_unpermute(
permuted_tokens=hidden_states, permuted_tokens=hidden_states,
sorted_indices=torch.abs(context_metadata["expanded_row_idx"]), sorted_indices=torch.abs(combine_metadata.expanded_row_idx),
probs=context_metadata["topk_weights"], probs=combine_metadata.topk_weights,
) )
if len(self.original_shape) == 3: if len(combine_metadata.restore_shape) == 3:
final_hidden_states = final_hidden_states.view(self.original_shape) final_hidden_states = final_hidden_states.view(combine_metadata.restore_shape)
# these values are no longer used, so they need to be set to None for memory release. # these values are no longer used, so they need to be set to None for memory release.
return TokenCombineResult(routed_out=final_hidden_states) return final_hidden_states
class TokenDispatcherWithAll2AllV(MoETokenDispatcher): class TokenDispatcherWithAll2AllV(MoETokenDispatcher[MoEAllToAllCombineMetadata]):
""" """
The implementation of the AlltoAll-based token dispatcher, which handles token The implementation of the AlltoAll-based token dispatcher, which handles token
dispatching on the sequence level instead of token level. The core of this implementation dispatching on the sequence level instead of token level. The core of this implementation
@@ -402,12 +369,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.with_quant = False
self.num_local_experts = kwargs.get("num_local_experts", 0) self.num_local_experts = kwargs.get("num_local_experts", 0)
self.hidden_shape = None
self.hidden_shape_before_permute = None
assert self.num_local_experts > 0, "Expected at least one expert" assert self.num_local_experts > 0, "Expected at least one expert"
if self.num_local_experts > 1: if self.num_local_experts > 1:
self.expert_ids_per_ep_rank = torch.tensor( self.expert_ids_per_ep_rank = torch.tensor(
@@ -432,19 +395,12 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
def token_dispatch( def token_dispatch(
self, self,
hidden_states: torch.Tensor, token_dispatch_input: MoETokenDispatchInput,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
): ):
self.with_quant = with_quant with_quant = token_dispatch_input.quant.is_int_quant
self.hidden_shape = hidden_states.shape hidden_states = token_dispatch_input.hidden_states
topk_weights = token_dispatch_input.topk_weights
topk_ids = token_dispatch_input.topk_ids
( (
permutated_local_input_tokens, permutated_local_input_tokens,
@@ -452,12 +408,13 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
tokens_per_expert, tokens_per_expert,
input_splits, input_splits,
output_splits, output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices, global_input_tokens_local_experts_indices,
hidden_shape,
hidden_shape_before_permute,
) = self._dispatch_preprocess(hidden_states, topk_ids) ) = self._dispatch_preprocess(hidden_states, topk_ids)
dynamic_scale_after_all2all = None dynamic_scale_after_all2all = None
if self.with_quant: if with_quant:
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(permutated_local_input_tokens) permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(permutated_local_input_tokens)
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all( _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
dynamic_scale, output_splits, input_splits, self.ep_group dynamic_scale, output_splits, input_splits, self.ep_group
@@ -474,64 +431,66 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
# Postprocess # Postprocess
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = ( global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = (
self._dispatch_postprocess( self._dispatch_postprocess(
global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices global_input_tokens,
dynamic_scale_after_all2all,
global_input_tokens_local_experts_indices,
with_quant,
) )
) )
context_metadata = { return MoETokenDispatchOutput(
"input_splits": input_splits,
"output_splits": output_splits,
"topk_weights": topk_weights,
"reversed_local_input_permutation_mapping": reversed_local_input_permutation_mapping,
"reversed_global_input_permutation_mapping": reversed_global_input_permutation_mapping,
}
return TokenDispatchResult(
hidden_states=global_input_tokens, hidden_states=global_input_tokens,
dynamic_scale=dynamic_scale_final, dynamic_scale=dynamic_scale_final,
group_list=tokens_per_expert, group_list=tokens_per_expert,
group_list_type=1, group_list_type=1,
context_metadata=context_metadata, combine_metadata=MoEAllToAllCombineMetadata(
input_splits=input_splits,
output_splits=output_splits,
topk_weights=topk_weights,
reversed_local_input_permutation_mapping=reversed_local_input_permutation_mapping,
reversed_global_input_permutation_mapping=reversed_global_input_permutation_mapping,
hidden_shape=hidden_shape,
hidden_shape_before_permute=hidden_shape_before_permute,
),
) )
def token_combine(self, hidden_states, context_metadata, bias=None): def token_combine(self, hidden_states, combine_metadata, bias=None):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
# 1. Preprocess using metadata # 1. Preprocess using metadata
hidden_states = self._combine_preprocess(hidden_states, context_metadata) hidden_states = self._combine_preprocess(hidden_states, combine_metadata)
# 2. AllToAll # 2. AllToAll
_, permutated_local_input_tokens, handle = async_all_to_all( _, permutated_local_input_tokens, handle = async_all_to_all(
hidden_states, hidden_states,
context_metadata["input_splits"], combine_metadata.input_splits,
context_metadata["output_splits"], combine_metadata.output_splits,
self.ep_group, self.ep_group,
) )
handle.wait() handle.wait()
hidden_states.untyped_storage().resize_(0) hidden_states.untyped_storage().resize_(0)
# 3. Postprocess using metadata # 3. Postprocess using metadata
output = self._combine_postprocess(permutated_local_input_tokens, context_metadata) output = self._combine_postprocess(permutated_local_input_tokens, combine_metadata)
return TokenCombineResult(routed_out=output) return output
def _dispatch_preprocess(self, hidden_states, topk_ids): def _dispatch_preprocess(self, hidden_states, topk_ids):
assert self.hidden_shape is not None hidden_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_states.size(-1)) hidden_states = hidden_states.view(-1, hidden_states.size(-1))
( (
tokens_per_expert, tokens_per_expert,
input_splits, input_splits,
output_splits, output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices, global_input_tokens_local_experts_indices,
num_out_tokens,
) = self._preprocess(topk_ids) ) = self._preprocess(topk_ids)
hidden_shape_before_permute = hidden_states.shape
self.hidden_shape_before_permute = hidden_states.shape
permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
tokens=hidden_states, tokens=hidden_states,
indices=topk_ids, indices=topk_ids,
num_out_tokens=self.num_out_tokens, num_out_tokens=num_out_tokens,
) )
return ( return (
@@ -540,15 +499,16 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
tokens_per_expert, tokens_per_expert,
input_splits, input_splits,
output_splits, output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices, global_input_tokens_local_experts_indices,
hidden_shape,
hidden_shape_before_permute,
) )
def _preprocess(self, topk_ids: torch.Tensor): def _preprocess(self, topk_ids: torch.Tensor):
num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts) num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts)
ep_size = self.ep_size ep_size = self.ep_size
self.num_out_tokens = topk_ids.numel() num_out_tokens = topk_ids.numel()
input_splits = ( input_splits = (
num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts) num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
@@ -585,19 +545,19 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
num_tokens_per_local_expert, num_tokens_per_local_expert,
input_splits, input_splits,
output_splits, output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices, global_input_tokens_local_experts_indices,
num_out_tokens,
) )
def _dispatch_postprocess( def _dispatch_postprocess(
self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices, with_quant
): ):
# Early return if no local experts or no tokens # Early return if no local experts or no tokens
if self.num_local_experts <= 1: if self.num_local_experts <= 1:
return global_input_tokens, dynamic_scale_after_all2all, None return global_input_tokens, dynamic_scale_after_all2all, None
# Handle quantized case # Handle quantized case
if self.with_quant: if with_quant:
assert global_input_tokens_local_experts_indices is not None, ( assert global_input_tokens_local_experts_indices is not None, (
"global_input_tokens_local_experts_indices must be provided" "global_input_tokens_local_experts_indices must be provided"
) )
@@ -612,20 +572,26 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
) )
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
def _combine_preprocess(self, hidden_states: torch.Tensor, context_metadata: dict) -> torch.Tensor: def _combine_preprocess(
self, hidden_states: torch.Tensor, combine_metadata: MoEAllToAllCombineMetadata
) -> torch.Tensor:
# Unpermutation 2: expert output to AlltoAll input # Unpermutation 2: expert output to AlltoAll input
if hidden_states.shape[0] > 0 and self.num_local_experts > 1: rev_global = combine_metadata.reversed_global_input_permutation_mapping
rev_global = context_metadata["reversed_global_input_permutation_mapping"] if hidden_states.shape[0] > 0 and self.num_local_experts > 1 and rev_global is not None:
hidden_states = torch_npu.npu_moe_token_unpermute(hidden_states, rev_global) hidden_states = torch_npu.npu_moe_token_unpermute(hidden_states, rev_global)
return hidden_states return hidden_states
def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor, context_metadata: dict) -> torch.Tensor: def _combine_postprocess(
self,
permutated_local_input_tokens: torch.Tensor,
combine_metadata: MoEAllToAllCombineMetadata,
) -> torch.Tensor:
# Unpermutation 1: AlltoAll output to output # Unpermutation 1: AlltoAll output to output
output = torch_npu.npu_moe_token_unpermute( output = torch_npu.npu_moe_token_unpermute(
permuted_tokens=permutated_local_input_tokens, permuted_tokens=permutated_local_input_tokens,
sorted_indices=context_metadata["reversed_local_input_permutation_mapping"].to(torch.int32), sorted_indices=combine_metadata.reversed_local_input_permutation_mapping.to(torch.int32),
probs=context_metadata["topk_weights"], probs=combine_metadata.topk_weights,
restore_shape=self.hidden_shape_before_permute, restore_shape=combine_metadata.hidden_shape_before_permute,
) )
output = output.view(self.hidden_shape) output = output.view(combine_metadata.hidden_shape)
return output return output

View File

@@ -16,24 +16,30 @@
# #
"""Ascend quantization module. """Ascend quantization module.
This module provides quantization support for Ascend NPU. This module intentionally avoids eager imports so that importing lightweight
submodules (for example ``quant_type``) does not trigger heavy registration
Supported quantization tools: paths and circular imports during startup.
- ModelSlim: Use AscendModelSlimConfig
- LLM-Compressor (compressed_tensors): Use AscendCompressedTensorsConfig
Public API:
- Config classes: AscendModelSlimConfig, AscendCompressedTensorsConfig
- For scheme implementations, import from vllm_ascend.quantization.methods
""" """
# LLM-Compressor (compressed_tensors) quantization config from typing import TYPE_CHECKING, Any
from .compressed_tensors_config import AscendCompressedTensorsConfig
# ModelSlim quantization config if TYPE_CHECKING:
from .modelslim_config import AscendModelSlimConfig from .compressed_tensors_config import AscendCompressedTensorsConfig
from .modelslim_config import AscendModelSlimConfig
__all__ = [ __all__ = [
"AscendModelSlimConfig", "AscendModelSlimConfig",
"AscendCompressedTensorsConfig", "AscendCompressedTensorsConfig",
] ]
def __getattr__(name: str) -> Any:
if name == "AscendModelSlimConfig":
from .modelslim_config import AscendModelSlimConfig
return AscendModelSlimConfig
if name == "AscendCompressedTensorsConfig":
from .compressed_tensors_config import AscendCompressedTensorsConfig
return AscendCompressedTensorsConfig
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -255,28 +255,34 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
enable_force_load_balance: bool = False, enable_force_load_balance: bool = False,
log2phy: torch.Tensor | None = None, log2phy: torch.Tensor | None = None,
global_redundant_expert_num=0, global_redundant_expert_num=0,
**kwargs, pertoken_scale: torch.Tensor | None = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.quant_method.apply( return self.quant_method.apply(
layer, layer=layer,
x, x=x,
router_logits, router_logits=router_logits,
top_k, top_k=top_k,
renormalize, renormalize=renormalize,
use_grouped_topk, use_grouped_topk=use_grouped_topk,
global_num_experts, global_num_experts=global_num_experts,
expert_map, expert_map=expert_map,
topk_group, topk_group=topk_group,
num_expert_group, num_expert_group=num_expert_group,
custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func, scoring_func=scoring_func,
routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
is_prefill, is_prefill=is_prefill,
enable_force_load_balance, enable_force_load_balance=enable_force_load_balance,
log2phy, log2phy=log2phy,
global_redundant_expert_num, global_redundant_expert_num=global_redundant_expert_num,
**kwargs, pertoken_scale=pertoken_scale,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
mc2_mask=mc2_mask,
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

View File

@@ -18,19 +18,11 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from enum import Enum
from typing import Any from typing import Any
import torch import torch
from vllm_ascend.quantization.quant_type import QuantType
class QuantType(Enum):
"""Quantization type enum for MoE schemes."""
NONE = 0
W8A8 = 1
W4A8 = 2
MXFP8 = 3
class AscendLinearScheme(ABC): class AscendLinearScheme(ABC):
@@ -245,7 +237,10 @@ class AscendMoEScheme(ABC):
enable_force_load_balance: bool = False, enable_force_load_balance: bool = False,
log2phy: torch.Tensor | None = None, log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
**kwargs, pertoken_scale: Any | None = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward computation for MoE layer. """Forward computation for MoE layer.
@@ -268,7 +263,10 @@ class AscendMoEScheme(ABC):
enable_force_load_balance: Whether to force load balancing. enable_force_load_balance: Whether to force load balancing.
log2phy: Logical to physical expert mapping. log2phy: Logical to physical expert mapping.
global_redundant_expert_num: Number of redundant experts. global_redundant_expert_num: Number of redundant experts.
**kwargs: Additional keyword arguments. pertoken_scale: Optional per-token activation scale from prepare stage.
activation: Expert MLP activation type.
apply_router_weight_on_input: Whether to pre-scale hidden states by router weights.
mc2_mask: Optional mask used by MC2 dispatch.
Returns: Returns:
Output tensor after MoE computation. Output tensor after MoE computation.

View File

@@ -25,8 +25,9 @@ from vllm.config import get_current_vllm_config
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from .base import AscendMoEScheme from .base import AscendMoEScheme, QuantType
from .registry import register_scheme from .registry import register_scheme
@@ -103,6 +104,8 @@ def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
class AscendW4A16FusedMoEMethod(AscendMoEScheme): class AscendW4A16FusedMoEMethod(AscendMoEScheme):
"""FusedMoE method for Ascend W4A16.""" """FusedMoE method for Ascend W4A16."""
quant_type: QuantType = QuantType.W4A16
def __init__(self) -> None: def __init__(self) -> None:
self.transpose_weight = True self.transpose_weight = True
self.num_bits = 4 # dtype = torch.int4 self.num_bits = 4 # dtype = torch.int4
@@ -192,7 +195,10 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
enable_force_load_balance: bool = True, enable_force_load_balance: bool = True,
log2phy: torch.Tensor | None = None, log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
**kwargs, pertoken_scale: Any | None = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, ( assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
"Number of global experts mismatch (excluding redundancy)" "Number of global experts mismatch (excluding redundancy)"
@@ -217,20 +223,26 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
moe_comm_method = _EXTRA_CTX.moe_comm_method moe_comm_method = _EXTRA_CTX.moe_comm_method
return moe_comm_method.fused_experts( return moe_comm_method.fused_experts(
hidden_states=x, fused_experts_input=build_fused_experts_input(
w1=layer.w13_weight_packed, hidden_states=x,
w2=layer.w2_weight_packed, topk_weights=topk_weights,
w1_scale=layer.w13_weight_scale, topk_ids=topk_ids,
w2_scale=layer.w2_weight_scale, w1=layer.w13_weight_packed,
w1_offset=layer.w13_weight_offset, w2=layer.w2_weight_packed,
w2_offset=layer.w2_weight_offset, quant_type=self.quant_type,
topk_weights=topk_weights, dynamic_eplb=self.dynamic_eplb,
topk_ids=topk_ids, expert_map=expert_map,
use_int4_w4a16=True, global_redundant_expert_num=global_redundant_expert_num,
expert_map=expert_map, mc2_mask=mc2_mask,
log2phy=log2phy, apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb, log2phy=log2phy,
mc2_mask=kwargs.get("mc2_mask"), pertoken_scale=pertoken_scale,
activation=activation,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_offset=layer.w13_weight_offset,
w2_offset=layer.w2_weight_offset,
)
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

View File

@@ -28,6 +28,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
from .base import AscendLinearScheme, AscendMoEScheme, QuantType from .base import AscendLinearScheme, AscendMoEScheme, QuantType
@@ -343,7 +344,10 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
enable_force_load_balance: bool = False, enable_force_load_balance: bool = False,
log2phy: torch.Tensor | None = None, log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
**kwargs, pertoken_scale: torch.Tensor | None = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, ( assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
"Number of global experts mismatch (excluding redundancy)" "Number of global experts mismatch (excluding redundancy)"
@@ -377,20 +381,26 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
moe_comm_method = _EXTRA_CTX.moe_comm_method moe_comm_method = _EXTRA_CTX.moe_comm_method
return moe_comm_method.fused_experts( return moe_comm_method.fused_experts(
hidden_states=x, fused_experts_input=build_fused_experts_input(
w1=[layer.w13_weight], hidden_states=x,
w2=[layer.w2_weight], topk_weights=topk_weights,
w1_scale=[layer.w13_weight_scale], topk_ids=topk_ids,
w2_scale=[layer.w2_weight_scale], w1=[layer.w13_weight],
w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None, w2=[layer.w2_weight],
w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None, quant_type=self.quant_type,
topk_weights=topk_weights, dynamic_eplb=self.dynamic_eplb,
topk_ids=topk_ids, expert_map=expert_map,
use_int4_w4a8=True, global_redundant_expert_num=global_redundant_expert_num,
expert_map=expert_map, mc2_mask=mc2_mask,
log2phy=log2phy, apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb, log2phy=log2phy,
mc2_mask=kwargs.get("mc2_mask"), pertoken_scale=pertoken_scale,
activation=activation,
w1_scale=[layer.w13_weight_scale],
w2_scale=[layer.w2_weight_scale],
w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None,
w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None,
)
) )
def process_scale(self, weight: torch.Tensor, scale, per_group_scale): def process_scale(self, weight: torch.Tensor, scale, per_group_scale):

View File

@@ -29,6 +29,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.flash_common3_context import get_flash_common3_context from vllm_ascend.flash_common3_context import get_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
from .base import AscendLinearScheme, AscendMoEScheme, QuantType from .base import AscendLinearScheme, AscendMoEScheme, QuantType
@@ -182,7 +183,9 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
log2phy: torch.Tensor | None = None, log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
pertoken_scale: Any | None = None, pertoken_scale: Any | None = None,
**kwargs, activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None) zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -249,19 +252,24 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale] w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale]
final_hidden_states = moe_comm_method.fused_experts( final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x, fused_experts_input=build_fused_experts_input(
pertoken_scale=pertoken_scale, hidden_states=x,
w1=w1, topk_weights=topk_weights,
w1_scale=w1_scale, topk_ids=topk_ids,
w2=w2, w1=w1,
w2_scale=w2_scale, w2=w2,
topk_weights=topk_weights, quant_type=self.quant_type,
topk_ids=topk_ids, dynamic_eplb=self.dynamic_eplb,
use_int8_w8a8=True, expert_map=expert_map,
expert_map=expert_map, global_redundant_expert_num=global_redundant_expert_num,
log2phy=log2phy, mc2_mask=mc2_mask,
dynamic_eplb=self.dynamic_eplb, apply_router_weight_on_input=apply_router_weight_on_input,
mc2_mask=kwargs.get("mc2_mask"), log2phy=log2phy,
pertoken_scale=pertoken_scale,
activation=activation,
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
)
) )
if zero_expert_num > 0 and zero_expert_type is not None: if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result final_hidden_states += zero_expert_result

View File

@@ -31,6 +31,7 @@ from vllm_ascend.device.mxfp_compat import (
ensure_mxfp8_moe_available, ensure_mxfp8_moe_available,
) )
from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from .base import AscendLinearScheme, AscendMoEScheme, QuantType from .base import AscendLinearScheme, AscendMoEScheme, QuantType
from .registry import register_scheme from .registry import register_scheme
@@ -170,7 +171,10 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
enable_force_load_balance: bool = True, enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None, log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
**kwargs, pertoken_scale: Any | None = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
expected = global_num_experts - global_redundant_expert_num expected = global_num_experts - global_redundant_expert_num
assert router_logits.shape[1] == expected, "Number of global experts mismatch (excluding redundancy)" assert router_logits.shape[1] == expected, "Number of global experts mismatch (excluding redundancy)"
@@ -198,23 +202,29 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
moe_comm_method = _EXTRA_CTX.moe_comm_method moe_comm_method = _EXTRA_CTX.moe_comm_method
return moe_comm_method.fused_experts( return moe_comm_method.fused_experts(
hidden_states=x, fused_experts_input=build_fused_experts_input(
w1=layer.w13_weight, hidden_states=x,
w1_scale=layer.w13_weight_scale, topk_weights=topk_weights,
w2=layer.w2_weight, topk_ids=topk_ids,
w2_scale=layer.w2_weight_scale, w1=layer.w13_weight,
topk_weights=topk_weights, w2=layer.w2_weight,
topk_ids=topk_ids, quant_type=self.quant_type,
use_int8_w8a8=False, dynamic_eplb=self.dynamic_eplb,
expert_map=expert_map, expert_map=expert_map,
log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num,
dynamic_eplb=self.dynamic_eplb, mc2_mask=mc2_mask,
mc2_mask=kwargs.get("mc2_mask"), apply_router_weight_on_input=apply_router_weight_on_input,
use_mxfp_quant=True, log2phy=log2phy,
act_quant_type=torch.float8_e4m3fn, pertoken_scale=pertoken_scale,
weight_quant_type=torch.float8_e4m3fn, activation=activation,
scale_type=FLOAT8_E8M0FNU_DTYPE, mxfp_act_quant_type=torch.float8_e4m3fn,
per_token_scale_type=FLOAT8_E8M0FNU_DTYPE, mxfp_weight_quant_type=torch.float8_e4m3fn,
mxfp_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
mxfp_per_token_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
mxfp_use_bf16=(x.dtype == torch.bfloat16),
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
) )
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):

View File

@@ -0,0 +1,33 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Shared quantization enum definitions.
Keep this module lightweight and side-effect free so core runtime modules can
import QuantType without triggering heavy quantization package initialization.
"""
from enum import Enum
class QuantType(Enum):
"""Quantization type enum for MoE schemes."""
NONE = 0
W8A8 = 1
W4A8 = 2
MXFP8 = 3
W4A16 = 4