[main] [refactor] refactor fused_moe.py to enable token_dispatchers (#2570)

### What this PR does / why we need it?
Enable the token_dispatcher path to replace the `fused_experts_with_xxx` functions in eager mode.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
E2E and unit tests.
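
For reference, a minimal sketch (not part of this PR) of the eager-mode flow the refactor enables: token dispatch, grouped MLP, token combine. Argument names are taken from the updated unit test in this diff; the `moe_forward` wrapper itself is illustrative only.

```python
# Sketch of the dispatch -> grouped MLP -> combine flow, adapted from the
# updated unit test below; `moe_forward` is a hypothetical helper.
import torch_npu
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
    TokenDispatcherWithAllGather


def moe_forward(hidden_states, w1, w2, topk_weights, topk_ids, row_idx,
                num_experts, top_k, num_local_experts, expert_map=None):
    dispatcher = TokenDispatcherWithAllGather(
        num_experts=num_experts,
        top_k=top_k,
        num_local_experts=num_local_experts)
    # Route tokens to experts; the dispatcher returns the permuted tokens and
    # the per-expert group list consumed by npu_grouped_matmul.
    dispatch = dispatcher.token_dispatch(
        hidden_states=hidden_states,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        row_idx=row_idx,
        expert_map=expert_map,
        apply_router_weight_on_input=False)
    x = dispatch["hidden_states"]
    group_list = dispatch["group_list"]
    group_list_type = dispatch.get("group_list_type", 1)
    # Grouped gate/up projection, SwiGLU activation, grouped down projection.
    x = torch_npu.npu_grouped_matmul(x=[x],
                                     weight=[w1.transpose(1, 2)],
                                     split_item=2,
                                     group_list_type=group_list_type,
                                     group_type=0,
                                     group_list=group_list)[0]
    x = torch_npu.npu_swiglu(x)
    x = torch_npu.npu_grouped_matmul(x=[x],
                                     weight=[w2.transpose(1, 2)],
                                     split_item=2,
                                     group_list_type=group_list_type,
                                     group_type=0,
                                     group_list=group_list)[0]
    # Un-permute and weight-combine the expert outputs back to token order.
    return dispatcher.token_combine(hidden_states=x, bias=None)
```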


- vLLM version: v0.10.1.1
- vLLM main:
704432af3c

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: sherie <963372609@qq.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Co-authored-by: shiyuan680 <72335504+shiyuan680@users.noreply.github.com>
Author: weichen
Date: 2025-08-28 10:13:35 +08:00
Committed by: GitHub
Parent: 936c102105
Commit: 320edde2df
10 changed files with 1066 additions and 1639 deletions


@@ -24,10 +24,12 @@ from unittest.mock import MagicMock, patch
import pytest
import torch
import torch_npu
from vllm.model_executor.layers.activation import SiluAndMul
from vllm_ascend.ops.fused_moe import fused_experts
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
TokenDispatcherWithAllGather
NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
@@ -35,6 +37,38 @@ TOP_KS = [2, 6]
DEVICE = ["npu"]
def apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
) -> torch.Tensor:
w1 = w1.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
w2 = w2.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)[0]
return hidden_states
def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
@@ -60,7 +94,7 @@ def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_fused_experts(
def test_token_dispatcher_with_all_gather(
m: int,
n: int,
k: int,
@@ -75,19 +109,23 @@ def test_fused_experts(
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
score = torch.randn((m, e), device=device, dtype=dtype)
expert_map = None
local_e = e
w1_local = w1
w2_local = w2
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randint(0,
e, (local_e, ),
device=device,
dtype=torch.int32)
e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None
e_ids = torch.arange(local_e * 0,
local_e * (0 + 1),
device=device,
dtype=torch.int32)
expert_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
expert_map[e_ids] = torch.arange(local_e,
device=device,
dtype=torch.int32)
w1_local = w1[e_ids]
w2_local = w2[e_ids]
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
@@ -99,11 +137,42 @@ def test_fused_experts(
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())
output = fused_experts(a, w1, w2, topk_weights, topk_ids, row_idx, topk,
e_map)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
dispatcher_kwargs = {
"num_experts": e,
"top_k": topk,
"num_local_experts": local_e,
}
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
apply_router_weight_on_input = False
dispatch_output = dispatcher.token_dispatch(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
w1=w1_local,
w2=w2_local,
group_list=group_list,
group_list_type=group_list_type)
combined_output = dispatcher.token_combine(hidden_states=expert_output,
bias=None)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
expert_map)
torch.testing.assert_close(combined_output,
torch_output,
atol=4e-2,
rtol=1)
torch.npu.empty_cache()


@@ -22,7 +22,6 @@ from vllm.config import CacheConfig
from vllm.distributed.parallel_state import GroupCoordinator
from vllm_ascend.models.deepseek_v2 import (
CustomDeepseekV2DecoderLayer, CustomDeepseekV2ForCausalLM,
CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
CustomDeepseekV2RowParallelLinear,
@@ -115,7 +114,8 @@ def mock_distributed():
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \ patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group, patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group), \ _PP=pp_group), \
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group): patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
patch("torch.npu.current_device", return_value=0):
yield yield
@@ -266,54 +266,3 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
kv_lora_rank=16,
prefix="layers.1.self_attn")
assert hasattr(attn, "q_proj")
@patch("torch_npu.npu_add_rms_norm")
@patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm,
mock_distributed, base_config,
vllm_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
torch.randn(2, 128))
base_config.n_routed_experts = 4
layer = CustomDeepseekV2DecoderLayer(config=base_config,
prefix="layers.0",
model_config=vllm_config.model_config,
cache_config=CacheConfig(),
quant_config=None)
assert isinstance(layer.mlp, CustomDeepseekV2MoE)
x = torch.randn(2, 4, 128)
positions = torch.arange(4).repeat(2, 1)
with patch.object(layer.self_attn, "forward", Mock(return_value=torch.randn(2, 4, 128))), \
patch.object(layer.mlp, "forward", Mock(return_value=torch.randn(2, 4, 128))):
hidden_states, residual = layer(positions, x, None)
assert hidden_states.shape == (2, 4, 128)
base_config.n_routed_experts = None
layer = CustomDeepseekV2DecoderLayer(config=base_config,
prefix="layers.0",
model_config=vllm_config.model_config,
quant_config=None)
assert isinstance(layer.mlp, CustomDeepseekV2MLP)
def test_custom_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
model = CustomDeepseekV2ForCausalLM(vllm_config=vllm_config)
input_ids = torch.randint(0, 10000, (2, 4))
positions = torch.arange(4).repeat(2, 1)
with patch.object(model.model,
"forward",
return_value=torch.randn(2, 4, 128)):
output = model(input_ids, positions)
assert output.shape == (2, 4, 128)
weights = [("model.embed_tokens.weight", torch.randn(10000, 128))]
with patch(
"vllm.model_executor.model_loader.weight_utils.default_weight_loader"
):
loaded = model.load_weights(weights)
assert loaded is not None


@@ -22,11 +22,15 @@ import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import (FusedMoEState,
_get_fused_moe_state)
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
AscendUnquantizedFusedMoEMethod,
unified_apply_mlp)
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402
from vllm_ascend.utils import AscendSocVersion, adapt_patch
adapt_patch(True)
@@ -56,7 +60,73 @@ def mock_npu_format_cast(weight_data, format):
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
# init dist env patch
mock_setup_token_dispatchers = MagicMock()
mock_token_dispatcher_with_allgather = MagicMock()
mock_token_dispatcher_with_all2allv = MagicMock()
mock_token_dispatcher_with_mc2 = MagicMock()
mock_dispatch_result_allgather = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([8, 16], dtype=torch.int64),
"group_list_type": 0,
}
mock_combine_result_allgather = torch.randn(16, 2)
mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
mock_dispatch_result_all2allv = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
}
mock_combine_result_all2allv = torch.randn(16, 2)
mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
mock_dispatch_result_mc2 = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
"assist_info_for_combine": torch.randn(16, 2),
"ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
}
mock_combine_result_mc2 = torch.randn(16, 2)
mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
captured_dispatchers = {}
def capture_register(dispatcher_instance):
key = dispatcher_instance.__class__.__name__
captured_dispatchers[key] = dispatcher_instance
if key == 'TokenDispatcherWithAllGather':
captured_dispatchers[key] = mock_token_dispatcher_with_allgather
elif key == 'TokenDispatcherWithAll2AllV':
captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
elif key == 'TokenDispatcherWithMC2':
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
mock_register_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher',
side_effect=capture_register)
mock_get_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher',
side_effect=lambda name: captured_dispatchers.get(name))
default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
mock_forward_context_obj = MagicMock(
fused_moe_state=FusedMoEState.AllGather,
token_dispatcher=default_mock_token_dispatcher,
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
mc2_mask=torch.zeros(16, dtype=torch.bool),
padded_num_tokens=16,
with_quant=False)
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
@@ -66,12 +136,10 @@ def mock_dist_env(mocker: MockerFixture):
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \
patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce',
return_value=torch.randn(5, 32)), \
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter',
return_value=torch.randn(5, 32)), \
patch('torch.distributed.all_gather'), \
patch('torch.distributed.all_to_all_single'), \
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
@@ -82,22 +150,31 @@ def mock_dist_env(mocker: MockerFixture):
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.fused_moe.get_forward_context',
return_value=MagicMock(
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
)), \
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
scheduler_config=MagicMock(max_num_seqs=4),
model_config=MagicMock(max_model_len=2048)
)):
yield
)), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers):
yield {
'mock_forward_context_obj': mock_forward_context_obj,
'mock_token_dispatcher_with_allgather':
mock_token_dispatcher_with_allgather,
'mock_token_dispatcher_with_all2allv':
mock_token_dispatcher_with_all2allv,
'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
}
mock_register_token_dispatcher_patcher.stop()
mock_get_token_dispatcher_patcher.stop()
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
# init moe env patch
with patch('torch_npu.npu_moe_gating_top_k', return_value=(
torch.randn(8, 2),
@@ -144,7 +221,6 @@ def mock_moe_env(mocker: MockerFixture):
@pytest.fixture
def default_moe_config():
"""default moe config"""
return {
'num_experts': 8,
'top_k': 2,
@@ -188,7 +264,6 @@ class MockQuantMethod(nn.Module):
class MockFusedMoEMethod(FusedMoEMethodBase):
# TODO(bnell): also pass quant_config?
moe = MagicMock()
def __init__(self):
@@ -223,13 +298,11 @@ class TestAscendFusedMoe:
assert hasattr(layer, 'w13_weight')
assert hasattr(layer, 'w2_weight')
# check group_topk
with pytest.raises(AssertionError):
error_config = default_moe_config.copy()
error_config['use_grouped_topk'] = True
layer = AscendFusedMoE(**error_config)
# check scoring_func
with pytest.raises(ValueError):
error_config = default_moe_config.copy()
error_config['scoring_func'] = "random"
@@ -254,14 +327,7 @@ class TestAscendFusedMoe:
[None, None, False, 1, None], [None, None, True, 5, 1],
[None, None, False, 5, 1]])
def test_forward(self, mock_dist_env, default_moe_config, others_param):
"""
1 test has shared_experts
2 test has top_k
3 test is_prefill is true
4 test single num_tokens(decode)
5 test ep_size is 1 and is_prefill is true
6 test ep_size is 1 and is_prefill is False
"""
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
inputs = torch.randn(num_tokens, 32)
router_logits = torch.randn(num_tokens, 8)
@@ -327,25 +393,42 @@ class TestAscendUnquantizedFusedMoEMethod:
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
"""
1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
2 test use_select_experts and fused_experts
3 test use select_gating_topk_softmax_experts and fused_experts
4 test use select_experts and fused_experts_with_all2all_buffer
"""
global_num_experts, ep_size = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1))
ep_size, is_prefill, is_deepseek_v3_r1),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
layer.w13_weight = torch.randn(8, 16, 1)
layer.w2_weight = torch.randn(16, 8, 1)
local_num_experts = 2
hidden_size = 2
intermediate_size_per_partition = 4
layer.w13_weight = torch.randn(local_num_experts,
intermediate_size_per_partition * 2,
hidden_size)
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
intermediate_size_per_partition)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
@@ -354,29 +437,38 @@ class TestAscendUnquantizedFusedMoEMethod:
global_num_experts=global_num_experts,
is_prefill=is_prefill)
if ep_size == 1:
assert result.shape == (16, 2)
else:
assert result.shape == x.shape
expected_shape = (16, 2)
assert result.shape == expected_shape
@pytest.mark.parametrize("others_param",
[[16, False], [1, True], [1, False], [4, False]])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
"""
1 test use_select_experts and use fused_expters_with_mc2
2 test use_select_experts and fused_experts_with_all2all_buffer
3 test use_select_experts and fused_experts_with_all2all
4 test use_select_experts and fused_experts
"""
ep_size, alltoall_buffer = others_param
is_prefill = False
forward_context = MagicMock(
fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, True),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER", with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
alltoall_buffer), \ alltoall_buffer), \
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \ patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3): patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]) expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2) x = torch.randn(8, 2, 2)
@@ -386,8 +478,16 @@ class TestAscendUnquantizedFusedMoEMethod:
if alltoall_buffer:
moe_method.max_model_len = 1
layer = MagicMock()
layer.w13_weight = torch.randn(8, 16, 1)
layer.w2_weight = torch.randn(16, 8, 1)
local_num_experts = 2
hidden_size = 2
intermediate_size_per_partition = 4
layer.w13_weight = torch.randn(local_num_experts,
intermediate_size_per_partition * 2,
hidden_size)
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
intermediate_size_per_partition)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
@@ -397,10 +497,9 @@ class TestAscendUnquantizedFusedMoEMethod:
expert_map=expert_map,
is_prefill=is_prefill)
if ep_size == 16 or ep_size == 1:
assert result.shape == (16, 2)
else:
assert result.shape == x.shape
expected_shape = (16, 2)
assert result.shape == expected_shape
class TestExpertsSelector:
@@ -426,3 +525,239 @@ class TestExpertsSelector:
assert topk_weights.shape == (8, 2)
assert topk_ids.shape == (8, 2)
class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
@patch('vllm_ascend.ops.fused_moe.get_mc2_group')
@patch('vllm_ascend.ops.fused_moe.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_dynamic_quant')
@patch('torch_npu.npu_dequant_swiglu_quant')
def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
mock_npu_dynamic_quant,
mock_npu_grouped_matmul,
mock_is_310p,
mock_get_mc2_group,
mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
mock_forward_context.fused_moe_state = FusedMoEState.MC2
mock_get_forward_context.return_value = mock_forward_context
mock_mc2_group = MagicMock()
mock_get_mc2_group.return_value = mock_mc2_group
mock_is_310p.return_value = False
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
127, (10, 20),
dtype=torch.int8),
torch.rand(10,
1,
dtype=torch.float32))
mock_npu_grouped_matmul.side_effect = [[
torch.randint(-2147483648, 2147483647, (10, 40), dtype=torch.int32)
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
mock_npu_dequant.return_value = (torch.randn(10,
40,
dtype=torch.bfloat16),
torch.randn(10,
1,
dtype=torch.float32))
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
w1 = torch.randint(-128, 127, (5, 20, 40), dtype=torch.int8)
w1_scale = torch.randn(5, 40, dtype=torch.float32)
w2 = torch.randint(-128, 127, (5, 40, 20), dtype=torch.int8)
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=None)
mock_get_forward_context.assert_called()
self.assertTrue(mock_forward_context.with_quant)
self.assertEqual(mock_forward_context.fused_moe_state,
FusedMoEState.MC2)
mock_npu_dynamic_quant.assert_called()
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_dequant.assert_called_once()
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
@patch('vllm_ascend.ops.fused_moe.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_without_quantization(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = False
mock_get_forward_context.return_value = mock_forward_context
mock_is_310p.return_value = False
mock_npu_grouped_matmul.side_effect = [[
torch.randn(10, 40, dtype=torch.float16)
], [torch.randn(10, 20, dtype=torch.float16)]]
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
hidden_states = torch.randn(10, 20, dtype=torch.float16)
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
topk_scales = torch.randn(10, 1, dtype=torch.float16)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=None,
w2=w2,
w2_scale=None,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales)
mock_get_forward_context.assert_called()
self.assertFalse(mock_forward_context.with_quant)
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
mock_forward_context.fused_moe_state = "NOT_MC2"
mock_get_forward_context.return_value = mock_forward_context
mock_npu_grouped_matmul.side_effect = [[
torch.randn(10, 40, dtype=torch.bfloat16)
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
mock_npu_swiglu.return_value = torch.randn(10,
40,
dtype=torch.bfloat16)
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
127, (10, 40),
dtype=torch.int8),
torch.rand(10,
1,
dtype=torch.float32))
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=provided_dynamic_scale,
group_list_type=1,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=None)
mock_get_forward_context.assert_called()
self.assertTrue(mock_forward_context.with_quant)
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
mock_npu_dynamic_quant.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
@patch('vllm_ascend.ops.fused_moe.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_without_quantization_310p(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = False
mock_get_forward_context.return_value = mock_forward_context
mock_is_310p.return_value = True
mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
mock_gmm2_out = torch.randn(10, 20, dtype=torch.float16)
mock_npu_grouped_matmul.side_effect = [[mock_gmm1_out],
[mock_gmm2_out]]
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
hidden_states = torch.randn(10, 20, dtype=torch.float16)
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
topk_scales = torch.randn(10, 1, dtype=torch.float16)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=None,
w2=w2,
w2_scale=None,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales)
mock_get_forward_context.assert_called()
self.assertFalse(mock_forward_context.with_quant)
mock_is_310p.assert_called_once()
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)


@@ -25,8 +25,8 @@ from tests.ut.base import PytestBase, TestBase
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig,
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
TokenDispatcherWithMC2)
from vllm_ascend.utils import adapt_patch  # noqa E402
TokenDispatcherWithMC2, _Dispatchers, _register_token_dispatcher,
get_token_dispatcher, setup_token_dispatchers)
class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase):
@@ -90,7 +90,7 @@ class TestTokenDispatcherWithMC2(TestBase):
self.forward_context = MagicMock()
self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
self.forward_context_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context",
"vllm.forward_context.get_forward_context",
return_value=self.forward_context)
self.forward_context_patch.start()
@@ -100,28 +100,18 @@ class TestTokenDispatcherWithMC2(TestBase):
return_value=AscendSocVersion.A3)
self.ascend_soc_version_patch.start()
# Mock get_ascend_config()
self.ascend_config = MagicMock()
self.ascend_config.torchair_graph_config.enabled = False
self.ascend_config_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config",
return_value=self.ascend_config)
self.ascend_config_patch.start()
kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128} kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
self.dispatcher = TokenDispatcherWithMC2(**kwargs) self.dispatcher = TokenDispatcherWithMC2(**kwargs)
self.row_idx = torch.arange(10, dtype=torch.int32)
def tearDown(self): def tearDown(self):
self.mc2_group_patch.stop() self.mc2_group_patch.stop()
self.forward_context_patch.stop() self.forward_context_patch.stop()
self.ascend_soc_version_patch.stop() self.ascend_soc_version_patch.stop()
self.ascend_config_patch.stop()
def test_init(self): def test_init(self):
# self.assertEqual(self.dispatcher.moe_all_to_all_group_name, "hccl_123")
self.assertEqual(self.dispatcher.ep_rank_id, 0) self.assertEqual(self.dispatcher.ep_rank_id, 0)
self.assertEqual(self.dispatcher.ep_world_size, 8) self.assertEqual(self.dispatcher.ep_world_size, 8)
self.assertFalse(self.dispatcher.torchair_graph_enabled)
self.assertFalse(self.dispatcher.with_quant) self.assertFalse(self.dispatcher.with_quant)
self.assertTrue(self.dispatcher.enable_dispatch_v2) self.assertTrue(self.dispatcher.enable_dispatch_v2)
self.assertTrue(self.dispatcher.need_extra_args) self.assertTrue(self.dispatcher.need_extra_args)
@@ -149,9 +139,10 @@ class TestTokenDispatcherWithMC2(TestBase):
return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch: return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
output = self.dispatcher.token_dispatch(hidden_states, output = self.dispatcher.token_dispatch(hidden_states,
topk_weights, topk_ids, topk_weights, topk_ids,
expert_map) self.row_idx, expert_map)
mock_dispatch.assert_called_once() mock_dispatch.assert_called_once()
self.assertEqual(output[0], 1) # group_list_type == 1 self.assertEqual(output["group_list_type"],
1) # group_list_type == 1
def test_token_dispatch_with_shared_experts_and_quant(self): def test_token_dispatch_with_shared_experts_and_quant(self):
self.shared_experts = MagicMock() self.shared_experts = MagicMock()
@@ -166,20 +157,13 @@ class TestTokenDispatcherWithMC2(TestBase):
with patch("torch_npu.npu_moe_distribute_dispatch_v2", with patch("torch_npu.npu_moe_distribute_dispatch_v2",
return_value=(torch.randn(10, 128), ) * 5): return_value=(torch.randn(10, 128), ) * 5):
with patch( self.dispatcher.token_dispatch(self.hidden_states,
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch", self.topk_weights,
autospec=True): torch.randint(0, 8, (10, 1)),
with patch( self.row_idx,
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor", torch.tensor(
autospec=True) as mock_wait: [0, 1, 2, 3, 4, 5, 6, 7]),
self.dispatcher.token_dispatch( shared_experts=self.shared_experts)
self.hidden_states,
self.topk_weights,
torch.randint(0, 8, (10, 1)),
torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
shared_experts=self.shared_experts)
mock_wait.assert_any_call(self.hidden_states,
self.topk_weights)
def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True
@@ -213,13 +197,7 @@ class TestTokenDispatcherWithMC2(TestBase):
with patch("torch_npu.npu_moe_distribute_combine_v2",
return_value=torch.randn(10, 128)):
with patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
autospec=True):
with patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
autospec=True):
self.dispatcher.token_combine(self.hidden_states)
class TestTokenDispatcherWithAllGather(TestBase):
@@ -257,6 +235,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
)
self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
self.row_idx = torch.arange(10, dtype=torch.int32)
def tearDown(self):
self.patcher_moe_init_routing.stop()
@@ -268,14 +247,14 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch(
hidden_states, topk_weights, topk_ids, None)
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, self.row_idx, None)
# Verify npu_moe_init_routing is called
self.mock_moe_init_routing.assert_called_once()
args, kwargs = self.mock_moe_init_routing.call_args
self.assertEqual(group_list_type, 0)
self.assertEqual(results["group_list_type"], 0)
def test_token_dispatch_with_quant(self):
kwargs = {
@@ -292,11 +271,11 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_dispatch(
hidden_states, topk_weights, topk_ids, None)
# Verify quant mode returns group_list_type=1
self.assertEqual(group_list_type, 0)
results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights, topk_ids,
self.row_idx, None)
self.assertEqual(results["group_list_type"], 0)
def test_token_combine_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
@@ -337,19 +316,9 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7], [0.6], [0.5]])  # topk=1
topk_ids = torch.tensor([[0], [1], [2]])
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch(
hidden_states, topk_weights, topk_ids, None)
self.assertEqual(sorted_hidden_states.shape, (6, 128))
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, None)
self.assertEqual(results["hidden_states"].shape, (6, 128))
def test_token_dispatch_invalid_topk_when_router_weight(self):
self.dispatcher.apply_router_weight_on_input = True
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
with self.assertRaises(AssertionError):
self.dispatcher.token_dispatch(
hidden_states, topk_weights,
torch.tensor([[0, 1], [1, 2], [2, 3]]), None)
class TestTokenDispatcherWithAll2AllV(TestBase):
@@ -443,6 +412,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
num_experts=4,
num_local_experts=2,
with_quant=False)
self.row_idx = torch.arange(10, dtype=torch.int32)
def test_token_dispatch(self):
hidden_states = torch.randn(8, 16)
@@ -457,6 +427,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
@@ -504,6 +475,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
@@ -532,6 +504,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
@@ -553,9 +526,126 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
log2phy=log2phy)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
class TestDispatcherRegistry(TestBase):
def setUp(self):
_Dispatchers.clear()
def tearDown(self):
_Dispatchers.clear()
def test_register_and_get_token_dispatcher(self):
mock_dispatcher = MagicMock()
mock_dispatcher.__class__.__name__ = "MockDispatcher"
_register_token_dispatcher(mock_dispatcher)
self.assertIn("MockDispatcher", _Dispatchers)
self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
self.assertIs(retrieved_dispatcher, mock_dispatcher)
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
self, mock_register, mock_allgather_class):
kwargs = {"top_k": 2, "num_experts": 8}
mock_instance = MagicMock()
mock_allgather_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
setup_token_dispatchers(ep_size=1, **kwargs)
mock_allgather_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
self, mock_register, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
mock_instance = MagicMock()
mock_all2allv_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
setup_token_dispatchers(ep_size=2, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_all2allv_instance = MagicMock()
mock_mc2_instance = MagicMock()
mock_all2allv_class.return_value = mock_all2allv_instance
mock_mc2_class.return_value = mock_mc2_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_mc2_class.assert_called_once_with(**kwargs)
self.assertEqual(mock_register.call_count, 2)
mock_register.assert_any_call(mock_all2allv_instance)
mock_register.assert_any_call(mock_mc2_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_existing_all2allv = MagicMock()
mock_existing_mc2 = MagicMock()
_Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
_Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_not_called()
mock_mc2_class.assert_not_called()
mock_register.assert_not_called()
self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
mock_existing_all2allv)
self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
mock_existing_mc2)


@@ -1,82 +0,0 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a8_dynamic import fused_experts_with_all2all
class TestAscendW8A8FusedMoEMethod(TestBase):
def setUp(self):
self.hidden_size = 128
self.num_tokens = 128
self.placeholder = torch.randn(self.num_tokens,
self.hidden_size,
dtype=torch.bfloat16)
@patch("torch.distributed.all_to_all_single")
@patch("torch_npu.npu_moe_re_routing")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_dynamic_quant")
@patch("torch_npu.npu_moe_finalize_routing")
@patch("torch_npu.npu_moe_init_routing")
def test_fused_experts_with_all2all(self, mock_moe_init_routing,
mock_moe_finalize_routing,
mock_dynamic_quant, mock_swiglu,
mock_grouped_matmul,
mock_moe_re_routing,
mock_all_to_all_single):
expert_map = MagicMock()
ep_group = MagicMock()
placeholder_int8 = torch.randint(0,
100,
(self.num_tokens, self.hidden_size),
dtype=torch.int8)
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)
mock_moe_init_routing.return_value = (
placeholder_int8,
placeholder_ones,
placeholder_ones,
)
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
torch.randint(0,
100,
(self.num_tokens, ),
dtype=torch.int32),
self.placeholder)
mock_grouped_matmul.return_value = self.placeholder
mock_swiglu.return_value = self.placeholder
mock_dynamic_quant.return_value = (
placeholder_int8,
torch.randn(self.num_tokens),
)
mock_moe_finalize_routing.return_value = self.placeholder
row_idx_len = self.num_tokens * 8
row_idx = (torch.arange(
0,
row_idx_len,
dtype=torch.int32,
).view(8, -1).permute(1, 0).contiguous())
result = fused_experts_with_all2all(
hidden_states=self.placeholder,
w1=self.placeholder,
w1_scale=self.placeholder,
w2=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
row_idx=row_idx,
top_k=8,
expert_map=expert_map,
ep_group=ep_group,
log2phy=None,
global_redundant_expert_num=256,
)
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.shape, (128, 128))


@@ -46,6 +46,18 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool,
return FusedMoEState.MC2
def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str:
if ep_size == 1:
return "TokenDispatcherWithAllGather"
if ep_size < 16:
return "TokenDispatcherWithAll2AllV"
if with_prefill:
return "TokenDispatcherWithAll2AllV"
return "TokenDispatcherWithMC2"
@contextmanager
def set_ascend_forward_context(
attn_metadata: Any,
@@ -87,6 +99,14 @@ def set_ascend_forward_context(
forward_context.fused_moe_state = fused_moe_state
forward_context.in_profile_run = in_profile_run
with_quant = vllm_config.quant_config is not None
forward_context.with_quant = with_quant
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
get_token_dispatcher
dispatcher_name = get_dispatcher_name(ep_size, with_prefill)
dispatcher = get_token_dispatcher(dispatcher_name)
forward_context.token_dispatcher = dispatcher
# NOTE: This cannot be set using set_forward_context
# due to multiple warmups before actual capturing
forward_context.capturing = False


@@ -16,14 +16,14 @@
# Adapted from vllm/tests/kernels/test_moe.py
import os
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
import torch_npu
from torch import nn
from vllm.config import get_current_vllm_config
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
@@ -49,9 +49,8 @@ from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
dispose_tensor, get_all_reduce_merge_state,
get_ascend_soc_version,
get_rm_router_logits_state, is_310p)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor,
get_all_reduce_merge_state,
get_rm_router_logits_state, is_310p)
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -122,149 +121,6 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
return topk_ids_pad, unpad_indices
def fused_experts_with_mc2(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
moe_parallel_config: FusedMoEParallelConfig,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: Optional[str] = None,
shared_experts: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
quant_mode = 0
ep_rank_id = moe_parallel_config.ep_rank
ep_world_size = moe_parallel_config.ep_size
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3)
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
moe_expert_num = len(expert_map)
kwargs_mc2 = {
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
}
stage1_kwargs = {
"scales": None,
"quant_mode": quant_mode,
"group_ep": moe_all_to_all_group_name,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
}
if need_extra_args:
stage1_kwargs.update({
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage1_kwargs)
output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
0:5]
if shared_experts is not None:
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
shared_act = shared_experts.act_fn(shared_gate_up)
w1 = w1.transpose(1, 2)
group_list = expert_token_nums.to(torch.int64)
gate_up_out_list = torch_npu.npu_grouped_matmul(
x=[expand_x],
weight=[w1],
split_item=2,
# 1 means count mode, to avoid cumulative operation of the group list
group_list_type=1,
group_type=0,
group_list=group_list,
)[0]
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
w2 = w2.transpose(1, 2)
down_out_list = torch_npu.npu_grouped_matmul(
x=[gate_up_out],
weight=[w2],
split_item=2,
group_list_type=1,
group_type=0,
group_list=group_list,
)[0]
# moeCombine
kwargs_mc2 = {
"expand_x": down_out_list,
"expert_ids": topk_ids,
"expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
}
tp_recv_counts = output[5]
stage3_kwargs = {
"ep_send_counts": ep_recv_counts,
"group_ep": moe_all_to_all_group_name,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
}
if enable_dispatch_v2:
stage3_kwargs.update({
"assist_info_for_combine":
assist_info_for_combine,
})
else:
stage3_kwargs.update({
"expand_idx": assist_info_for_combine,
})
if need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage3_kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
if shared_experts is None:
return hidden_states
else:
shared_hidden_states, _ = shared_experts.down_proj(shared_act)
return hidden_states, shared_hidden_states
def apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -318,248 +174,6 @@ def apply_mlp(
return hidden_states
# currently expert parallelism implemented with all2all
# is under-optimized.
def fused_experts_with_all2all(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens, _ = hidden_states.shape
num_experts = w1.shape[0]
if expert_map is not None:
global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_group.world_size
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
global_expert_tokens = torch.bincount(expanded_expert_idx,
minlength=global_num_experts)
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-1).sum(-1)
gather_sizes = torch.empty_like(scatter_sizes)
dist.all_to_all_single(gather_sizes,
scatter_sizes,
group=ep_group.device_group)
scatter_size_list = scatter_sizes.cpu().tolist()
gather_size_list = gather_sizes.cpu().tolist()
expanded_expert_idx = expanded_expert_idx % local_num_experts
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
scatter_size_list,
gather_size_list)
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
scatter_size_list,
gather_size_list)
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
sorted_local_expert_idx, local_num_experts).to(torch.int64)
hidden_states = hidden_states[sorted_idx]
else:
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts)
expert_tokens = expert_tokens.to(torch.int64)
w1 = w1.transpose(1, 2)
gate_up_out_list = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
)[0]
hidden_states = torch_npu.npu_swiglu(gate_up_out_list)
w2 = w2.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
)[0]
if expert_map is not None:
resorted_idx = torch.argsort(sorted_idx)
hidden_states = hidden_states[resorted_idx]
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
gather_size_list,
scatter_size_list)
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
else:
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states
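# Illustrative sketch (not part of this diff): before the variable-sized all_to_all
# of hidden states above, each rank exchanges per-rank row counts so both sides know
# the split sizes. The helper name below is hypothetical and for illustration only.
import torch
import torch.distributed as dist

def exchange_split_sizes(scatter_sizes: torch.Tensor, group) -> torch.Tensor:
    # scatter_sizes[i]: rows this rank sends to rank i;
    # gather_sizes[i]: rows this rank will receive from rank i.
    gather_sizes = torch.empty_like(scatter_sizes)
    dist.all_to_all_single(gather_sizes, scatter_sizes, group=group)
    return gather_sizes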
# currently expert parallelism implemented with all2all
# is under-optimized.
def fused_experts_with_all2all_buffer(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
max_model_len: int,
global_batch_size: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens, _ = hidden_states.shape
global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_group.world_size
row_idx_len = num_tokens * top_k
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
max_model_len // ep_group.world_size +
1) * top_k * 2
expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
expanded_expert_idx, global_num_experts, ep_group.world_size,
max_row_per_ep_rank, num_tokens, top_k)
hidden_states_pad_idx = torch.zeros(
expert_idx_buffer_scatter.shape,
dtype=expert_idx_buffer_scatter.dtype,
device=expert_idx_buffer_scatter.device)
non_pad_len = torch.sum((expert_idx_buffer_scatter
!= global_num_experts).to(torch.int32))
hidden_states_pad_idx[expert_idx_buffer_scatter !=
global_num_experts] = torch.arange(
non_pad_len,
dtype=expert_idx_buffer_scatter.dtype,
device=hidden_states.device)
hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
expert_idx_buffer_gather = torch.empty_like(
expert_idx_buffer_scatter,
dtype=expert_idx_buffer_scatter.dtype,
device=expert_idx_buffer_scatter.device)
hidden_states_buffer_gather = torch.empty_like(
hidden_states_buffer_scatter,
dtype=hidden_states_buffer_scatter.dtype,
device=hidden_states_buffer_scatter.device)
dist.all_to_all_single(expert_idx_buffer_gather,
expert_idx_buffer_scatter,
group=ep_group.device_group)
dist.all_to_all_single(hidden_states_buffer_gather,
hidden_states_buffer_scatter,
group=ep_group.device_group)
mask = expert_idx_buffer_gather != global_num_experts
local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
global_num_experts // ep_group.world_size)
hidden_states = hidden_states_buffer_gather[mask]
idx_type = local_expert_idx.dtype
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
sorted_local_expert_idx, local_num_experts).to(torch.int64)
hidden_states = hidden_states[sorted_idx]
group_list_type = 0
hidden_states = apply_mlp(hidden_states,
w1,
w2,
expert_tokens,
group_list_type=group_list_type)
resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
hidden_states = hidden_states[resorted_idx]
hidden_states_scatter = torch.zeros(
(mask.shape[0], hidden_states.shape[1]),
dtype=hidden_states.dtype,
device=hidden_states.device)
hidden_states_scatter[mask] = hidden_states
hidden_states_gatter = torch.empty_like(
hidden_states_scatter,
dtype=hidden_states_scatter.dtype,
device=hidden_states_scatter.device)
dist.all_to_all_single(hidden_states_gatter,
hidden_states_scatter,
group=ep_group.device_group)
hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter !=
global_num_experts]
if hidden_states_gatter.shape[0] != row_idx_len:
hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
dtype=hidden_states.dtype,
device=hidden_states.device)
hidden_states[unpad_indices != -1] = hidden_states_gatter
else:
        # TODO: Reorder device memory 2 times here, replace the current
        # implementation here when suitable operators become available.
hidden_states = hidden_states_gatter
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states
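# Worked example (assumed values, not part of this diff) for the buffer capacity
# computed above:
#   max_row_per_ep_rank = (-(-global_batch_size // ep_world_size) * max_model_len
#                          // ep_world_size + 1) * top_k * 2
# With global_batch_size=16, ep_world_size=4, max_model_len=1024, top_k=8:
#   ceil(16 / 4) = 4; 4 * 1024 // 4 = 1024; 1024 + 1 = 1025; 1025 * 8 * 2 = 16400
# rows reserved per EP rank before padding with the sentinel expert id.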
def fused_experts_moge(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
@@ -651,188 +265,228 @@ def fused_experts_moge(
    return final_hidden_states
def fused_experts_with_all2allv( def quant_apply_mlp(hidden_states: torch.Tensor,
token_dispatcher, w1: torch.Tensor,
probs, w1_scale: torch.Tensor,
routing_map, w2: torch.Tensor,
hidden_states: torch.Tensor, w2_scale: torch.Tensor,
w1: torch.Tensor, group_list: torch.Tensor,
w2: torch.Tensor, dynamic_scale: torch.Tensor = None,
): group_list_type: int = 1,
# Enable moe alltoallv, it's a balanced policy for precision and efficiency. w1_scale_bias: torch.Tensor = None,
(share_experts_output, dispatched_input, w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
tokens_per_expert) = (token_dispatcher.token_permutation( if dynamic_scale is None:
hidden_states, probs, routing_map)) unquantized_hidden_states = hidden_states
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert) hidden_states)
output, mlp_bias = token_dispatcher.token_unpermutation(expert_output) # Dispose the original unquantized hidden states
return output # to save npu memory because they're no longer used.
dispose_tensor(unquantized_hidden_states)
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
max_num_tokens: Optional[int] = None,
) -> torch.Tensor:
"""
Fused experts with top-k routing.
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
top_k: Number of experts to select.
expert_map: Expert mapping of shape (num_experts,).
Returns:
hidden_states: Hidden states after routing.
"""
"""
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
"""
# if torch.distributed.get_rank() == 0:
# print(w1.shape)
# print(hidden_states.shape)
original_shape = hidden_states.shape
# assert len(original_shape) == 2
num_tokens = hidden_states.shape[:-1].numel()
num_experts = w1.shape[0]
dtype = hidden_states.dtype
device = hidden_states.device
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
# ], "Only float32, float16, and bfloat16 are supported"
if apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
if expert_map is not None:
# Generate token indices and flatten
token_indices = (torch.arange(num_tokens,
device=device,
dtype=torch.int64).unsqueeze(1).expand(
-1, top_k).reshape(-1))
# Flatten token-to-expert mappings and map to local experts
weights_flat = topk_weights.view(-1)
experts_flat = topk_ids.view(-1)
local_experts_flat = expert_map[experts_flat]
# Filter valid token-expert pairs
mask = local_experts_flat != -1
filtered_weights = torch.where(
mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
filtered_experts = torch.where(
mask, local_experts_flat,
torch.full_like(local_experts_flat,
num_experts)).to(topk_ids.dtype)
# Sort by local expert IDs
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
sorted_token_indices = token_indices[sort_indices]
sorted_weights = filtered_weights[sort_indices]
# Compute token counts with minlength of num_experts
# This is equivalent to but faster than:
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
token_counts = torch.zeros(num_experts + 1,
device=device,
dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
token_counts = token_counts[:num_experts]
expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
# Rearrange hidden_states
sorted_hidden_states = hidden_states[sorted_token_indices]
else: else:
active_num = max_num_tokens if max_num_tokens is not None else num_tokens pertoken_scale = dynamic_scale
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=active_num)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens( bias1, bias2 = None, None
expanded_expert_idx, num_experts) _output_dtype = w2_scale.dtype
expert_tokens = expert_tokens.to(torch.int64)
is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
if w1_scale_bias is None and is_mc2:
w1_scale = w1_scale.to(torch.float32)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]
else:
if w1_scale_bias is not None:
if group_list_type == 0:
group_list = torch.cat(
[group_list[:1],
torch.diff(group_list, dim=0)])
group_list_type = 1
bias1 = [w1_scale_bias]
bias2 = [w2_scale_bias]
# TODO w4a8 scene: dynamic acquisition of dtype in the future
_output_dtype = torch.bfloat16
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
scale=[w1_scale],
bias=bias1,
per_token_scale=[pertoken_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
bias=bias2,
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
return hidden_states
def unquant_apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
w1 = w1.transpose(1, 2) w1 = w1.transpose(1, 2)
gate_up_out_list = torch_npu.npu_grouped_matmul( gate_up_out = torch_npu.npu_grouped_matmul(
x=[sorted_hidden_states], x=[hidden_states],
weight=[w1], weight=[w1],
split_item=2, split_item=2,
group_list_type=0, group_list_type=group_list_type,
group_type=0, group_type=0,
group_list=expert_tokens, group_list=group_list,
)[0] )[0]
if is_310p():
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
else:
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list) if topk_scales is not None:
gate_up_out *= topk_scales
w2 = w2.transpose(1, 2) w2 = w2.transpose(1, 2)
down_out_list = torch_npu.npu_grouped_matmul( hidden_states = torch_npu.npu_grouped_matmul(
x=[gate_up_out], x=[gate_up_out],
weight=[w2], weight=[w2],
split_item=2, split_item=2,
group_list_type=0, group_list_type=group_list_type,
group_type=0, group_type=0,
group_list=expert_tokens, group_list=group_list,
)[0] )[0]
return hidden_states
if expert_map is not None:
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
final_hidden_states = torch.zeros(*original_shape, def unified_apply_mlp(
device=hidden_states.device, hidden_states: torch.Tensor,
dtype=dtype) w1: torch.Tensor,
w1_scale: torch.Tensor,
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...] w2: torch.Tensor,
# This created multiple NaN and index_add_ will mix them up which harms accuracy w2_scale: torch.Tensor,
# remove this mask and filter after it being fixed group_list: torch.Tensor,
num_valid_tokens = mask.sum() dynamic_scale: torch.Tensor = None,
valid_token_mask = torch.arange( group_list_type: int = 1,
0, sorted_token_indices.shape[0], w1_scale_bias: torch.Tensor = None,
device=device).unsqueeze(1) < num_valid_tokens w2_scale_bias: torch.Tensor = None,
valid_output = torch.where( topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_token_mask, weighted_down_out, if get_forward_context().with_quant:
torch.zeros_like(weighted_down_out)).to(dtype) return quant_apply_mlp(hidden_states=hidden_states,
final_hidden_states.index_add_(0, sorted_token_indices, valid_output) w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias)
else: else:
scales = torch.ones_like( return unquant_apply_mlp(hidden_states=hidden_states,
topk_weights) if apply_router_weight_on_input else topk_weights w1=w1,
# TODO: Reorder device memory 2 times here, replace the current w2=w2,
# implementation here when suitable operators become available. group_list=group_list,
final_hidden_states = torch_npu.npu_moe_finalize_routing( group_list_type=group_list_type,
down_out_list, topk_scales=topk_scales)
skip1=None,
skip2=None,
bias=None,
scales=scales,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
def unified_fused_experts_eager(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
w1_scale: Optional[torch.Tensor] = None,
w1_scale_bias: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w2_scale_bias: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[Any] = None,
shared_dequant_scale: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
token_dispatcher = get_forward_context().token_dispatcher
results = token_dispatcher.token_dispatch(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input)
expert_output = unified_apply_mlp(
hidden_states=results["hidden_states"],
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=results["group_list"],
dynamic_scale=results.get("dynamic_scale"),
group_list_type=results.get("group_list_type"),
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=results.get("topk_scales"))
final_hidden_states = token_dispatcher.token_combine(expert_output)
return final_hidden_states return final_hidden_states
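# Illustrative sketch (not part of this diff): the eager path above always follows
# dispatch -> grouped MLP -> combine, with token_dispatch() returning a dict whose
# keys ("hidden_states", "group_list", "group_list_type", and optionally
# "dynamic_scale"/"topk_scales") feed unified_apply_mlp. Weights and scales below
# are placeholders for illustration only.
def eager_moe_sketch(dispatcher, hidden_states, topk_weights, topk_ids, row_idx,
                     w1, w2, w1_scale=None, w2_scale=None):
    results = dispatcher.token_dispatch(hidden_states=hidden_states,
                                        topk_weights=topk_weights,
                                        topk_ids=topk_ids,
                                        row_idx=row_idx)
    expert_output = unified_apply_mlp(hidden_states=results["hidden_states"],
                                      w1=w1,
                                      w1_scale=w1_scale,
                                      w2=w2,
                                      w2_scale=w2_scale,
                                      group_list=results["group_list"],
                                      dynamic_scale=results.get("dynamic_scale"),
                                      group_list_type=results.get("group_list_type"),
                                      topk_scales=results.get("topk_scales"))
    return dispatcher.token_combine(expert_output)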
@@ -914,65 +568,16 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
        if enable_force_load_balance and not self.use_aclgraph:
            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

        fused_moe_state = get_forward_context().fused_moe_state

        if fused_moe_state == FusedMoEState.MC2:
            return fused_experts_with_mc2(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                moe_parallel_config=self.moe.moe_parallel_config,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                shared_experts=shared_experts,
                mc2_mask=kwargs.get("mc2_mask", None))
        elif fused_moe_state in [
                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
        ]:
            return fused_experts(hidden_states=x,
                                 w1=layer.w13_weight,
                                 w2=layer.w2_weight,
                                 topk_weights=topk_weights,
                                 topk_ids=topk_ids,
                                 row_idx=row_idx,
                                 top_k=top_k,
                                 expert_map=expert_map)
        elif MOE_ALL2ALL_BUFFER:
            return fused_experts_with_all2all_buffer(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                row_idx=row_idx,
                top_k=top_k,
                max_model_len=self.max_model_len,
                global_batch_size=self.global_batch_size,
                expert_map=expert_map,
                ep_group=get_ep_group())
        elif fused_moe_state == FusedMoEState.All2AllSeq:
            token_dispatcher = kwargs.get("token_dispatcher")
            return fused_experts_with_all2allv(
                token_dispatcher=token_dispatcher,
                probs=topk_weights,
                routing_map=topk_ids,
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
            )
        else:
            return fused_experts_with_all2all(hidden_states=x,
                                              w1=layer.w13_weight,
                                              w2=layer.w2_weight,
                                              topk_weights=topk_weights,
                                              topk_ids=topk_ids,
                                              row_idx=row_idx,
                                              top_k=top_k,
                                              expert_map=expert_map,
                                              ep_group=get_ep_group())

        return unified_fused_experts_eager(hidden_states=x,
                                           w1=layer.w13_weight,
                                           w2=layer.w2_weight,
                                           topk_weights=topk_weights,
                                           topk_ids=topk_ids,
                                           row_idx=row_idx,
                                           expert_map=expert_map,
                                           shared_experts=shared_experts,
                                           mc2_mask=kwargs.get(
                                               "mc2_mask", None))
class AscendFusedMoE(FusedMoE):
@@ -1154,6 +759,19 @@ class AscendFusedMoE(FusedMoE):
                self.token_dispatcher, token_dispatcher1
            ]
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
with_quant = quant_config is not None
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
setup_token_dispatchers
setup_token_dispatchers(
ep_size,
top_k=self.top_k,
num_experts=self.global_num_experts,
num_global_redundant_experts=self.global_redundant_expert_num,
num_local_experts=self.local_num_experts,
with_quant=with_quant)
    def naive_multicast(self, x: torch.Tensor,
                        cu_tokens_across_dp_cpu: torch.Tensor):
        assert (len(x.shape) == 2)


@@ -22,21 +22,18 @@
# limitations under the License. # limitations under the License.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Any, Dict, Optional
import torch import torch
import torch_npu import torch_npu
from vllm.distributed.parallel_state import get_ep_group from vllm.distributed.parallel_state import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.distributed.tensor_parallel import ( from vllm_ascend.distributed.tensor_parallel import (
all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp, all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp,
all_to_all_sp2hp, gather_from_sequence_parallel_region, all_to_all_sp2hp, gather_from_sequence_parallel_region,
reduce_scatter_last_dim_to_tensor_parallel_region) reduce_scatter_last_dim_to_tensor_parallel_region)
from vllm_ascend.ops.comm_utils import async_all_to_all from vllm_ascend.ops.comm_utils import async_all_to_all
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
@@ -460,6 +457,31 @@ class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher):
return output, None return output, None
_Dispatchers: Dict[str, Any] = {}
def _register_token_dispatcher(dispatcher: Any):
_Dispatchers[dispatcher.__class__.__name__] = dispatcher
def get_token_dispatcher(name: str):
return _Dispatchers.get(name)
def setup_token_dispatchers(ep_size: int, **kwargs):
existing_dispatchers = set(_Dispatchers.keys())
if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
elif ep_size >= 16:
if "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
if "TokenDispatcherWithMC2" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithMC2(**kwargs))
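# Illustrative usage sketch (not part of this diff): dispatchers are registered once
# per EP-size regime and looked up by class name afterwards. The ep_size/top_k/expert
# counts below are assumed example values only.
setup_token_dispatchers(ep_size=1,
                        top_k=8,
                        num_experts=64,
                        num_global_redundant_experts=0,
                        num_local_experts=64,
                        with_quant=False)
dispatcher = get_token_dispatcher("TokenDispatcherWithAllGather")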
class MoETokenDispatcher(ABC): class MoETokenDispatcher(ABC):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
@@ -484,18 +506,19 @@ class MoETokenDispatcher(ABC):
return get_ep_group().world_size return get_ep_group().world_size
@abstractmethod @abstractmethod
def token_dispatch( def token_dispatch(self,
self, hidden_states: torch.Tensor,
hidden_states: torch.Tensor, topk_weights: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
topk_ids: torch.Tensor, row_idx: torch.Tensor,
expert_map: torch.Tensor, expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_gate_up: Optional[torch.Tensor] = None, shared_experts: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None, shared_gate_up: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None, shared_dequant_scale: Optional[torch.Tensor] = None,
): mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
raise NotImplementedError("Dispatch function not implemented.") raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod @abstractmethod
@@ -516,40 +539,39 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
self.ep_rank_id = get_mc2_group().rank_in_group self.ep_rank_id = get_mc2_group().rank_in_group
self.ep_world_size = get_mc2_group().world_size self.ep_world_size = get_mc2_group().world_size
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_dispatch_v2 = hasattr(torch_npu, self.enable_dispatch_v2 = hasattr(torch_npu,
"npu_moe_distribute_dispatch_v2") "npu_moe_distribute_dispatch_v2")
self.need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 self.need_extra_args = (
or self.torchair_graph_enabled) get_ascend_soc_version() == AscendSocVersion.A3)
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
self.a3_need_extra_args = \ self.a3_need_extra_args = \
get_ascend_soc_version() == AscendSocVersion.A3 get_ascend_soc_version() == AscendSocVersion.A3
self.output = None self.output = None
self.dynamic_scale = None
self.assist_info_for_combine = None self.assist_info_for_combine = None
self.ep_recv_counts = None self.ep_recv_counts = None
self.shared_act = None self.shared_act = None
self.topk_ids = None self.topk_ids = None
self.topk_weights = None self.topk_weights = None
self.shared_experts = None self.shared_experts = None
self.mc2_mask = None
def get_dispatch_mc2_kwargs(self, def get_dispatch_mc2_kwargs(
hidden_states: torch.Tensor, self,
topk_weights: torch.Tensor, hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_weights: torch.Tensor,
expert_map: torch.Tensor, topk_ids: torch.Tensor,
global_redundant_expert_num: int = 0): expert_map: torch.Tensor,
quant_mode = 0 global_redundant_expert_num: int = 0,
forward_context = get_forward_context() ):
mc2_mask = forward_context.mc2_mask
if self.with_quant: if self.with_quant:
quant_mode = 2
if (expert_map is not None): if (expert_map is not None):
moe_expert_num = len(expert_map) + global_redundant_expert_num moe_expert_num = len(expert_map) + global_redundant_expert_num
else: else:
moe_expert_num = global_redundant_expert_num moe_expert_num = global_redundant_expert_num
else: else:
quant_mode = 0
moe_expert_num = len(expert_map) moe_expert_num = len(expert_map)
kwargs_mc2 = { kwargs_mc2 = {
"x": hidden_states, "x": hidden_states,
@@ -575,28 +597,30 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
}) })
if self.a3_need_extra_args and self.enable_dispatch_v2: if self.a3_need_extra_args and self.enable_dispatch_v2:
stage1_kwargs.update({ stage1_kwargs.update({
"x_active_mask": mc2_mask, "x_active_mask": self.mc2_mask,
}) })
kwargs_mc2.update(stage1_kwargs) kwargs_mc2.update(stage1_kwargs)
return kwargs_mc2 return kwargs_mc2
def token_dispatch( def token_dispatch(self,
self, hidden_states: torch.Tensor,
hidden_states: torch.Tensor, topk_weights: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
topk_ids: torch.Tensor, row_idx: torch.Tensor,
expert_map: torch.Tensor, expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_gate_up: Optional[torch.Tensor] = None, shared_experts: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None, shared_gate_up: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None, shared_dequant_scale: Optional[torch.Tensor] = None,
): mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
self.expert_map = expert_map self.expert_map = expert_map
self.topk_ids = topk_ids self.topk_ids = topk_ids
self.topk_weights = topk_weights self.topk_weights = topk_weights
self.shared_experts = shared_experts self.shared_experts = shared_experts
self.mc2_mask = mc2_mask
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights, kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
topk_ids, expert_map, topk_ids, expert_map,
@@ -606,28 +630,27 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2) **kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream()) # comm_stream.wait_stream(torch.npu.current_stream())
expand_x, self.dynamic_scale, self.assist_info_for_combine, \ expand_x, dynamic_scale, self.assist_info_for_combine, \
expert_token_nums, self.ep_recv_counts = self.output[0:5] expert_token_nums, self.ep_recv_counts = self.output[0:5]
if self.with_quant: if self.with_quant:
if shared_experts is not None: if shared_experts is not None:
with npu_stream_switch("moe_secondary", 0): shared_act_out = shared_experts.act_fn(
npu_wait_tensor(shared_gate_up, expand_x) (shared_gate_up, shared_dequant_scale))
shared_act_out = shared_experts.act_fn( self.shared_act, self.swiglu_out_scale = \
(shared_gate_up, shared_dequant_scale)) shared_act_out[0], shared_act_out[1]
self.shared_act, self.swiglu_out_scale = \
shared_act_out[0], shared_act_out[1]
else: else:
if shared_experts is not None: if shared_experts is not None:
with npu_stream_switch("moe_secondary", 0): shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
npu_wait_tensor(hidden_states, topk_weights) self.shared_act = shared_experts.act_fn(shared_gate_up)
shared_gate_up, _ = shared_experts.gate_up_proj(
hidden_states)
npu_wait_tensor(shared_gate_up, expand_x)
self.shared_act = shared_experts.act_fn(shared_gate_up)
group_list_type = 1 group_list_type = 1
return group_list_type, expand_x, expert_token_nums return {
"group_list_type": group_list_type,
"hidden_states": expand_x,
"group_list": expert_token_nums,
"dynamic_scale": dynamic_scale,
}
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor): def get_combine_mc_kwargs(self, hidden_states: torch.Tensor):
assert self.expert_map is not None assert self.expert_map is not None
@@ -635,8 +658,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
assert self.topk_ids is not None assert self.topk_ids is not None
assert self.output is not None assert self.output is not None
moe_expert_num = len(self.expert_map) moe_expert_num = len(self.expert_map)
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
# moeCombine # moeCombine
kwargs_mc2 = { kwargs_mc2 = {
"expand_x": hidden_states, "expand_x": hidden_states,
@@ -677,7 +698,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
}) })
if self.a3_need_extra_args and self.enable_dispatch_v2: if self.a3_need_extra_args and self.enable_dispatch_v2:
stage3_kwargs.update({ stage3_kwargs.update({
"x_active_mask": mc2_mask, "x_active_mask": self.mc2_mask,
}) })
kwargs_mc2.update(stage3_kwargs) kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2 return kwargs_mc2
@@ -685,7 +706,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
def token_combine(self, def token_combine(self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
bias: torch.Tensor = None): bias: torch.Tensor = None):
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states) kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states)
hidden_states = torch_npu.npu_moe_distribute_combine_v2( hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2 **kwargs_mc2
@@ -695,15 +715,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
return hidden_states return hidden_states
else: else:
if self.with_quant: if self.with_quant:
with npu_stream_switch("moe_secondary", 0): shared_hidden_states, _ = self.shared_experts.down_proj(
npu_wait_tensor(self.shared_act, hidden_states) (self.shared_act, self.swiglu_out_scale))
shared_hidden_states, _ = self.shared_experts.down_proj(
(self.shared_act, self.swiglu_out_scale))
else: else:
with npu_stream_switch("moe_secondary", 0): shared_hidden_states, _ = self.shared_experts.down_proj(
npu_wait_tensor(self.shared_act, hidden_states) self.shared_act)
shared_hidden_states, _ = self.shared_experts.down_proj(
self.shared_act)
return hidden_states, shared_hidden_states return hidden_states, shared_hidden_states
@@ -711,13 +727,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.apply_router_weight_on_input = kwargs.get( self.apply_router_weight_on_input = False
"apply_router_weight_on_input")
self.top_k = kwargs.get("top_k")
self.max_num_tokens = kwargs.get("max_num_tokens") self.max_num_tokens = kwargs.get("max_num_tokens")
ep_size = kwargs.get("ep_size") self.num_experts_local = kwargs.get("num_local_experts", 0)
if ep_size is not None:
self.num_experts_local = self.num_experts // ep_size
self.sorted_weights = None self.sorted_weights = None
self.expanded_row_idx = None self.expanded_row_idx = None
self.sorted_token_indices = None self.sorted_token_indices = None
@@ -727,20 +739,20 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
self.topk_weights = None self.topk_weights = None
self.topk_ids = None self.topk_ids = None
def token_dispatch( def token_dispatch(self,
self, hidden_states: torch.Tensor,
hidden_states: torch.Tensor, topk_weights: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
topk_ids: torch.Tensor, row_idx: torch.Tensor,
expert_map: torch.Tensor, expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_gate_up: Optional[torch.Tensor] = None, shared_experts: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None, shared_gate_up: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None, shared_dequant_scale: Optional[torch.Tensor] = None,
): mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
self.original_shape = hidden_states.shape self.original_shape = hidden_states.shape
# assert len(original_shape) == 2
num_tokens = hidden_states.shape[:-1].numel() num_tokens = hidden_states.shape[:-1].numel()
dtype = hidden_states.dtype dtype = hidden_states.dtype
@@ -748,9 +760,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
self.expert_map = expert_map self.expert_map = expert_map
self.topk_weights = topk_weights self.topk_weights = topk_weights
self.topk_ids = topk_ids self.topk_ids = topk_ids
# assert dtype in [torch.float32, torch.float16, torch.bfloat16 self.apply_router_weight_on_input = apply_router_weight_on_input
# ], "Only float32, float16, and bfsloat16 are supported"
if self.apply_router_weight_on_input: if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2 assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)" ), "`topk_weights` should be in shape (num_tokens, topk)"
@@ -803,19 +813,13 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
sorted_hidden_states = hidden_states[self.sorted_token_indices] sorted_hidden_states = hidden_states[self.sorted_token_indices]
if self.with_quant: if self.with_quant:
group_list_type = 1 group_list_type = 1
expert_tokens = token_counts
else: else:
expert_tokens = torch.cumsum(token_counts, expert_tokens = torch.cumsum(token_counts,
dim=0, dim=0,
dtype=torch.int64) dtype=torch.int64)
group_list_type = 0 group_list_type = 0
else: else:
row_idx_len = num_tokens * self.top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=device).view(self.top_k,
-1).permute(
1, 0).contiguous())
active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens
sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states, hidden_states,
@@ -827,18 +831,23 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
expanded_expert_idx, self.num_experts_local) expanded_expert_idx, self.num_experts_local)
expert_tokens = expert_tokens.to(torch.int64) expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0 group_list_type = 0
return group_list_type, sorted_hidden_states, expert_tokens return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": expert_tokens,
}
def token_combine(self, def token_combine(self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
bias: torch.Tensor = None): bias: torch.Tensor = None):
assert self.mask is not None
assert self.sorted_token_indices is not None
assert self.sorted_weights is not None
assert self.original_shape is not None assert self.original_shape is not None
dtype = hidden_states.dtype dtype = hidden_states.dtype
device = hidden_states.device device = hidden_states.device
if self.expert_map is not None: if self.expert_map is not None:
assert self.mask is not None
assert self.sorted_token_indices is not None
assert self.sorted_weights is not None
weighted_down_out = hidden_states * \ weighted_down_out = hidden_states * \
self.sorted_weights.unsqueeze(1) self.sorted_weights.unsqueeze(1)
@@ -887,7 +896,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
expanded_src_to_dst_row=self.expanded_row_idx, expanded_src_to_dst_row=self.expanded_row_idx,
export_for_source_row=self.topk_ids, export_for_source_row=self.topk_ids,
) )
return final_hidden_states return final_hidden_states
@@ -895,29 +903,27 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(MoETokenDispatcher, self).__init__(**kwargs) super().__init__(**kwargs)
self.apply_router_weight_on_input = kwargs.get( self.apply_router_weight_on_input = False
"apply_router_weight_on_input") self.local_ep = 1
ep_size = kwargs.get("ep_size")
self.local_ep = ep_size
assert self.local_ep is not None
self.local_num_experts = self.num_experts // self.local_ep self.local_num_experts = self.num_experts // self.local_ep
self.local_num_group = self.top_k // self.local_ep self.local_num_group = self.top_k // self.local_ep
self.bsz = None self.bsz = None
def token_dispatch( def token_dispatch(self,
self, hidden_states: torch.Tensor,
hidden_states: torch.Tensor, topk_weights: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
topk_ids: torch.Tensor, row_idx: torch.Tensor,
expert_map: torch.Tensor, expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_gate_up: Optional[torch.Tensor] = None, shared_experts: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None, shared_gate_up: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None, shared_dequant_scale: Optional[torch.Tensor] = None,
): mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input: if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2 assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)" ), "`topk_weights` should be in shape (num_tokens, topk)"
@@ -932,7 +938,7 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
flatten_topk_ids = topk_ids.view(-1) flatten_topk_ids = topk_ids.view(-1)
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32) self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
self.sorted_hidden_states = hidden_states.index_select( sorted_hidden_states = hidden_states.index_select(
0, self.sorted_topk_ids // self.local_num_group) 0, self.sorted_topk_ids // self.local_num_group)
experts_id = torch.arange(0, experts_id = torch.arange(0,
@@ -942,15 +948,20 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
num_tokens_per_expert = ( num_tokens_per_expert = (
flatten_topk_ids.unsqueeze(-1) == experts_id).to( flatten_topk_ids.unsqueeze(-1) == experts_id).to(
torch.float32).sum(0) torch.float32).sum(0)
self.topk_scales = topk_weights.view(-1).index_select( topk_scales = topk_weights.view(-1).index_select(
0, self.sorted_topk_ids).unsqueeze(-1) 0, self.sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
return hidden_states, group_list group_list_type = 0
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": group_list,
"topk_scales": topk_scales,
}
def token_combine(self, def token_combine(self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
bias: torch.Tensor = None): bias: torch.Tensor = None):
assert self.local_ep is not None
unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to( unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
torch.int32) torch.int32)
unsorted_hidden_states = hidden_states.index_select( unsorted_hidden_states = hidden_states.index_select(
@@ -1009,18 +1020,19 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
self.local_expert_indices[i + 1] - self.local_expert_indices[i + 1] -
1), "local_expert_indices must be continuous" 1), "local_expert_indices must be continuous"
def token_dispatch( def token_dispatch(self,
self, hidden_states: torch.Tensor,
hidden_states: torch.Tensor, topk_weights: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
topk_ids: torch.Tensor, row_idx: torch.Tensor,
expert_map: torch.Tensor, expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_gate_up: Optional[torch.Tensor] = None, shared_experts: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None, shared_gate_up: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None, shared_dequant_scale: Optional[torch.Tensor] = None,
): mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
self.hidden_shape = hidden_states.shape self.hidden_shape = hidden_states.shape
self.topk_weights = topk_weights self.topk_weights = topk_weights
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights" assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"


@@ -26,9 +26,8 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.quantization.w8a8_dynamic import (fused_experts_with_all2all,
                                                   fused_experts_with_mc2)
class AscendW4A8DynamicLinearMethod:
@@ -291,48 +290,25 @@ class AscendW4A8DynamicFusedMoEMethod:
            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
        topk_weights = topk_weights.to(x.dtype)
if fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2( return unified_fused_experts_eager(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_second, w1_scale=layer.w13_weight_scale_second,
w2_scale=layer.w2_weight_scale_second, w2_scale=layer.w2_weight_scale_second,
w1_scale_bias=layer.w13_scale_bias, w1_scale_bias=layer.w13_scale_bias,
w2_scale_bias=layer.w2_scale_bias, w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
top_k=top_k, row_idx=row_idx,
expert_map=expert_map, expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name, log2phy=log2phy,
log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num,
global_redundant_expert_num=global_redundant_expert_num, shared_experts=shared_experts,
shared_experts=shared_experts, shared_gate_up=shared_gate_up,
quantized_x_for_share=shared_gate_up, shared_dequant_scale=shared_dequant_scale,
dynamic_scale_for_share=shared_dequant_scale, mc2_mask=kwargs.get("mc2_mask", None))
mc2_mask=kwargs.get("mc2_mask", None))
else:
# The current implementation of deepseek moe splits hidden_states
# according to tp_size before they are feed into layers module.
# Therefore, all2all is needed no matter how dp/tp is set so as to
# dispatch/combine tokens.
return fused_experts_with_all2all(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_second,
w2_scale=layer.w2_weight_scale_second,
w1_scale_bias=layer.w13_scale_bias,
w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
top_k=top_k,
expert_map=expert_map,
ep_group=self.ep_group,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
)
def process_scale(self, weight: torch.Tensor, scale, per_group_scale): def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
group_num, k, n = weight.shape group_num, k, n = weight.shape


@@ -18,17 +18,16 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch_npu
from vllm.distributed import GroupCoordinator, get_ep_group
from vllm.forward_context import get_forward_context
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                               dispose_tensor, get_ascend_soc_version)
def apply_mlp_decode(hidden_states: torch.Tensor,
@@ -197,520 +196,6 @@ def apply_mlp(hidden_states: torch.Tensor,
    return hidden_states
def fused_experts_with_mc2(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: str = "",
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
shared_gate_up: Optional[Any] = None,
shared_dequant_scale: Optional[Any] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
assert mc2_mask is not None
if log2phy is not None:
topk_ids = log2phy[topk_ids]
quant_mode = 2
ep_group = get_mc2_group()
ep_rank_id = ep_group.rank_in_group
ep_world_size = ep_group.world_size
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3)
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
if (expert_map is not None):
moe_expert_num = len(expert_map) + global_redundant_expert_num
else:
moe_expert_num = global_redundant_expert_num
# hidden_states = hidden_states.bfloat16()
kwargs_mc2 = {
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
}
stage1_kwargs = {
"scales": None,
"quant_mode": quant_mode,
"group_ep": moe_all_to_all_group_name,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
}
if need_extra_args:
stage1_kwargs.update({
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage1_kwargs)
output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
0:5]
if shared_experts is not None:
shared_act_out = shared_experts.act_fn(
(shared_gate_up, shared_dequant_scale))
shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
# `expand_x` will be disposed in the `apply_mlp` function
if w1_scale_bias is None:
down_out_list = apply_mlp_decode(expand_x,
w1,
w1_scale,
w2,
w2_scale,
expert_token_nums,
dynamic_scale=dynamic_scale)
else:
# w4a8 scene, cannot use apply_mlp_decode because the operator is not supported
down_out_list = apply_mlp(expand_x,
w1,
w1_scale,
w2,
w2_scale,
expert_token_nums,
dynamic_scale=dynamic_scale,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias)
# moeCombine
kwargs_mc2 = {
"expand_x": down_out_list,
"expert_ids": topk_ids,
"expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
}
tp_recv_counts = torch.empty(1,
dtype=torch.int32,
device=hidden_states.device)
stage3_kwargs = {
"ep_send_counts": ep_recv_counts,
"group_ep": moe_all_to_all_group_name,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
}
if enable_dispatch_v2:
stage3_kwargs.update({
"assist_info_for_combine":
assist_info_for_combine,
})
else:
stage3_kwargs.update({
"expand_idx": assist_info_for_combine,
})
if need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage3_kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
if shared_experts is None:
return hidden_states
else:
shared_output, _ = shared_experts.down_proj(
(shared_act, swiglu_out_scale))
return hidden_states, shared_output
def init_routing_quant(hidden_states, top_k, topk_ids, row_idx,
global_num_experts):
num_tokens, _ = hidden_states.shape
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
1, 0).contiguous().view(-1))
global_expert_tokens = torch.bincount(expanded_expert_idx,
minlength=global_num_experts)
global_expert_tokens = global_expert_tokens.to(torch.int32)
quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales
# currently expert parallelism implemented with all2all
# is under-optimized.
def fused_experts_with_all2all(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
):
if log2phy is not None:
topk_ids = log2phy[topk_ids]
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens, _ = hidden_states.shape
num_experts = w1.shape[0]
if expert_map is not None:
global_num_experts = len(expert_map) + global_redundant_expert_num
if hasattr(torch_npu, "npu_moe_init_routing_quant"):
quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
hidden_states,
expert_idx=topk_ids.to(torch.int32),
active_num=0,
expert_capacity=0,
expert_num=global_num_experts,
drop_pad_mode=0,
expert_tokens_num_mode=2,
expert_tokens_before_capacity_flag=False,
quant_mode=1,
)
else:
quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
hidden_states, top_k, topk_ids, row_idx, global_num_experts)
gather_sizes = global_expert_tokens.new_empty(
global_expert_tokens.shape[0])
dist.all_to_all_single(gather_sizes, global_expert_tokens)
token_counts_combined = torch.stack(
[gather_sizes, global_expert_tokens], dim=0)
token_counts_combined = token_counts_combined.view(
2, ep_group.world_size, -1).sum(dim=2)
token_counts_combined_cpu = token_counts_combined.to(
torch.device("cpu"), non_blocking=True).numpy()
all_tokens = gather_sizes.sum()
gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
quantized_tokens.shape[1])
dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
gather_size_list = token_counts_combined_cpu[1]
scatter_size_list = token_counts_combined_cpu[0]
dist.all_to_all_single(gathered_tokens, quantized_tokens,
scatter_size_list, gather_size_list)
dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
gather_size_list)
hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
gathered_tokens,
gather_sizes.view(ep_group.world_size, -1),
per_token_scales=dynamic_scale)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1
else:
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0
dynamic_scale = None
# `hidden_states` will be disposed in the `apply_mlp` function
hidden_states = apply_mlp(
hidden_states,
w1,
w1_scale,
w2,
w2_scale,
expert_tokens,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias)
if expert_map is not None:
reordered_outputs = torch.index_select(
hidden_states,
dim=0,
# Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
dist.all_to_all_single(hidden_states, reordered_outputs,
gather_size_list, scatter_size_list)
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=None,
drop_pad_mode=2)
else:
# TODO: device memory is reordered twice here; replace this implementation
# once suitable operators become available.
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states
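

# The block below is an illustrative, CPU-only sketch (not part of this module)
# of how fused_experts_with_all2all derives the all-to-all split sizes from the
# two count tensors exchanged above. The world size, expert layout and toy
# counts are assumptions for illustration only.
def _all2all_split_sizes_sketch():
    import torch

    world_size = 2
    # Tokens this rank produced for each of the 4 global experts (2 per rank);
    # this plays the role of `global_expert_tokens`.
    send_counts = torch.tensor([3, 1, 2, 4])
    # Counts received from peers for this rank's local experts; this plays the
    # role of `gather_sizes` filled by dist.all_to_all_single.
    recv_counts = torch.tensor([3, 1, 5, 0])
    combined = torch.stack([recv_counts, send_counts], dim=0)
    per_rank = combined.view(2, world_size, -1).sum(dim=2)
    # Split lists for dist.all_to_all_single(output, input, out_splits, in_splits):
    scatter_size_list = per_rank[0].tolist()  # tokens received from each rank
    gather_size_list = per_rank[1].tolist()   # tokens sent to each rank
    return scatter_size_list, gather_size_list
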
def fused_experts_with_allgather(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens = hidden_states.shape[0]
batch_size, hidden_size = hidden_states.shape
topk_weights = topk_weights.to(hidden_states.dtype)
ep_group = get_ep_group().device_group
ep_rank = torch.distributed.get_rank(group=ep_group)
ep_size = torch.distributed.get_world_size(ep_group)
global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_size
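# Quantize activations to int8 with per-token (dynamic) scales before routing.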
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
scale=pertoken_scale,
offset=None,
active_num=num_tokens * top_k,
expert_num=global_num_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[
ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
],
quant_mode=-1,
row_idx_type=1)
group_list_type = 1
sorted_topk_weight = torch.index_select(topk_weights.view(-1), 0,
expanded_x_idx)
row_index = expanded_x_idx // topk_ids.shape[-1]
row_index = row_index.to(torch.int64)
share_input = torch.zeros((batch_size, hidden_size),
dtype=torch.bfloat16,
device="npu")
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=expert_tokens,
output_dtype=torch.int32)[0]
# act_fn: swiglu
hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale.to(torch.float32),
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=expert_tokens,
activate_left=True,
quant_mode=1,
)
final_hidden_states = torch_npu.npu_grouped_matmul_finalize_routing(
hidden_states,
w2,
scale=w2_scale.to(torch.float32),
bias=None,
pertoken_scale=pertoken_scale.view(-1),
group_list=expert_tokens,
shared_input=share_input,
logit=sorted_topk_weight.to(torch.float32),
row_index=row_index,
output_bs=batch_size).to(torch.bfloat16)
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states
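

# Illustrative sketch (not part of this module): how `row_index` above recovers
# the originating token row from a flattened (token, top_k) slot index, which
# the finalize-routing kernel uses to scatter-add expert outputs back onto
# their source tokens. Toy values are assumptions.
def _row_index_sketch():
    import torch

    top_k = 2
    # Flattened slot indices for 3 tokens x top_k experts, in routed order.
    expanded_x_idx = torch.tensor([4, 0, 2, 5, 1, 3])
    row_index = expanded_x_idx // top_k  # originating token of each routed slot
    assert row_index.tolist() == [2, 0, 1, 2, 0, 1]
    return row_index
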
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens, _ = hidden_states.shape
num_experts = w1.shape[0]
dtype = hidden_states.dtype
device = hidden_states.device
if expert_map is not None:
# Generate token indices and flatten
token_indices = (torch.arange(num_tokens,
device=device,
dtype=torch.int64).unsqueeze(1).expand(
-1, top_k).reshape(-1))
# Flatten token-to-expert mappings and map to local experts
weights_flat = topk_weights.view(-1)
experts_flat = topk_ids.view(-1)
local_experts_flat = expert_map[experts_flat]
# Filter valid token-expert pairs
mask = local_experts_flat != -1
filtered_weights = torch.where(
mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
filtered_experts = torch.where(
mask, local_experts_flat,
torch.full_like(local_experts_flat,
num_experts)).to(topk_ids.dtype)
# Sort by local expert IDs
sort_indices = torch.argsort(filtered_experts)
sorted_token_indices = token_indices[sort_indices]
sorted_weights = filtered_weights[sort_indices]
# Compute token counts with minlength of num_experts
# This is equivalent to but faster than:
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
token_counts = torch.zeros(num_experts + 1,
device=device,
dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
expert_tokens = token_counts[:num_experts]
# Rearrange hidden_states
hidden_states = hidden_states[sorted_token_indices]
group_list_type = 1
else:
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0
# `hidden_states` will be disposed in the `apply_mlp` function
hidden_states = apply_mlp(hidden_states,
w1,
w1_scale,
w2,
w2_scale,
expert_tokens,
group_list_type=group_list_type)
if expert_map is not None:
hidden_states.mul_(sorted_weights.unsqueeze(1))
final_hidden_states = torch.zeros(*original_shape,
device=device,
dtype=dtype)
num_valid_tokens = mask.sum()
valid_token_mask = torch.arange(
0, sorted_token_indices.shape[0],
device=device).unsqueeze(1) < num_valid_tokens
hidden_states = hidden_states.masked_fill_(~valid_token_mask,
0).to(dtype)
final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
else:
# TODO: device memory is reordered twice here; replace this implementation
# once suitable operators become available.
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states
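

# The function below is an illustrative, CPU-only reference (not part of this
# module) for the expert_map branch of fused_experts above: map assignments to
# local experts, sort tokens by expert, count tokens per expert, and combine
# the (here randomly faked) expert outputs back per token. Shapes and values
# are assumptions for illustration only.
def _expert_map_permutation_sketch():
    import torch

    num_tokens, top_k, hidden, num_global, num_local = 4, 2, 8, 8, 4
    topk_ids = torch.randint(0, num_global, (num_tokens, top_k))
    topk_weights = torch.rand(num_tokens, top_k)
    # Experts 0..3 are local on this rank; all others map to -1.
    expert_map = torch.full((num_global, ), -1, dtype=torch.int64)
    expert_map[:num_local] = torch.arange(num_local)

    token_indices = torch.arange(num_tokens).repeat_interleave(top_k)
    local_experts = expert_map[topk_ids.view(-1)]
    mask = local_experts != -1
    # Invalid pairs get weight 0 and the sentinel id `num_local`, so they sort
    # to the end and drop out of the per-expert counts.
    weights = torch.where(mask, topk_weights.view(-1), torch.zeros(()))
    experts = torch.where(mask, local_experts,
                          torch.full((), num_local, dtype=torch.int64))
    order = torch.argsort(experts)
    sorted_tokens = token_indices[order]
    sorted_weights = weights[order]
    counts = torch.zeros(num_local + 1, dtype=torch.int64)
    counts.scatter_add_(0, experts, torch.ones_like(experts))
    expert_tokens = counts[:num_local]  # group sizes for the grouped matmul
    # Stand-in for the grouped MLP output, permuted into expert order.
    mlp_out = torch.randn(num_tokens * top_k, hidden)[order]
    mlp_out = mlp_out * sorted_weights.unsqueeze(1)
    valid = torch.arange(num_tokens * top_k).unsqueeze(1) < mask.sum()
    combined = torch.zeros(num_tokens, hidden).index_add_(
        0, sorted_tokens, mlp_out.masked_fill(~valid, 0.0))
    return combined, expert_tokens
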
class AscendW8A8DynamicLinearMethod:
    """Linear method for Ascend W8A8_DYNAMIC.
    """

@@ -905,68 +390,23 @@ class AscendW8A8DynamicFusedMoEMethod:
        topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
        topk_weights = topk_weights.to(x.dtype)

-        if fused_moe_state == FusedMoEState.AllGatherEP:
-            return fused_experts_with_allgather(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w1_scale=layer.w13_weight_scale,
-                w2=layer.w2_weight,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map)
-        elif fused_moe_state == FusedMoEState.MC2:
-            return fused_experts_with_mc2(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale_fp32,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts,
-                mc2_mask=kwargs.get("mc2_mask", None),
-                shared_gate_up=shared_gate_up,
-                shared_dequant_scale=shared_dequant_scale)
-        elif fused_moe_state in [
-                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
-        ]:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w1_scale=layer.w13_weight_scale,
-                                 w2=layer.w2_weight,
-                                 w2_scale=layer.w2_weight_scale,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 row_idx=row_idx,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
-        else:
-            # The current implementation of deepseek moe splits hidden_states
-            # according to tp_size before they are fed into the layers module.
-            # Therefore, all2all is needed no matter how dp/tp is set so as to
-            # dispatch/combine tokens.
-            return fused_experts_with_all2all(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w1_scale=layer.w13_weight_scale,
-                w2=layer.w2_weight,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                row_idx=row_idx,
-                top_k=top_k,
-                expert_map=expert_map,
-                ep_group=self.ep_group,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-            )
+        return unified_fused_experts_eager(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w1_scale=layer.w13_weight_scale,
+            w2=layer.w2_weight,
+            w2_scale=layer.w2_weight_scale,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            row_idx=row_idx,
+            expert_map=expert_map,
+            log2phy=log2phy,
+            global_redundant_expert_num=global_redundant_expert_num,
+            shared_experts=shared_experts,
+            shared_gate_up=shared_gate_up,
+            shared_dequant_scale=shared_dequant_scale,
+            mc2_mask=kwargs.get("mc2_mask", None))

    def process_weights_after_loading(self, layer):
        if self.transpose_weight: