[Refactor][MoE] remove redundant code after refactoring fused_moe (#2612)
### What this PR does / why we need it?
There are a lot of redundant codes related to moe here, and the
structure is not very clear.
We did the following things:
we have placed the relatively independent code related to apply_mlp into
a separate file;
removed the environment variables of alltoall_buffer and alltoall_seq.
Remove the code related to alltoall_buffer and alltoall_seq, and retain
the sole TokenDispatcher inheritance class.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e&ut
- vLLM version: v0.10.1.1
- vLLM main:
4071c76cf3
---------
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
1
.github/workflows/vllm_ascend_test.yaml
vendored
1
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -279,7 +279,6 @@ jobs:
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
|
||||
|
||||
@@ -108,14 +108,13 @@ def test_models_distributed_pangu():
|
||||
]
|
||||
max_tokens = 5
|
||||
|
||||
with VllmRunner(
|
||||
snapshot_download("vllm-ascend/pangu-pro-moe-pruing"),
|
||||
max_model_len=8192,
|
||||
enforce_eager=True,
|
||||
dtype="auto",
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="mp",
|
||||
) as vllm_model:
|
||||
with VllmRunner(snapshot_download("vllm-ascend/pangu-pro-moe-pruing"),
|
||||
max_model_len=8192,
|
||||
enforce_eager=True,
|
||||
dtype="auto",
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="mp",
|
||||
enable_expert_parallel=True) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
|
||||
@@ -141,28 +140,6 @@ def test_models_distributed_topk() -> None:
|
||||
vllm_model.generate(example_prompts, sampling_params)
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1"})
|
||||
def test_models_distributed_alltoallv() -> None:
|
||||
example_prompts = [
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
|
||||
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
|
||||
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
|
||||
]
|
||||
dtype = "half"
|
||||
sampling_params = SamplingParams(max_tokens=5,
|
||||
temperature=0.0,
|
||||
top_k=50,
|
||||
top_p=0.9)
|
||||
|
||||
with VllmRunner(
|
||||
"deepseek-ai/DeepSeek-V2-Lite",
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="mp",
|
||||
) as vllm_model:
|
||||
vllm_model.generate(example_prompts, sampling_params)
|
||||
|
||||
|
||||
def test_models_distributed_Qwen3_W8A8():
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
|
||||
69
tests/ut/ops/test_common_fused_moe.py
Normal file
69
tests/ut/ops/test_common_fused_moe.py
Normal file
@@ -0,0 +1,69 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.common_fused_moe import fused_experts_moge
|
||||
|
||||
|
||||
class TestFusedExpertsMoGE(TestBase):
|
||||
|
||||
def test_fused_experts_moge(self):
|
||||
with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
|
||||
patch('torch_npu.npu_swiglu') as mock_swiglu, \
|
||||
patch('vllm_ascend.utils.is_310p') as mock_is_310p:
|
||||
|
||||
mock_is_310p.return_value = False
|
||||
|
||||
mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
|
||||
torch.randn(x[0].shape[0], weight[0].shape[1])
|
||||
]
|
||||
|
||||
mock_swiglu.side_effect = lambda x: x
|
||||
|
||||
hidden_states = torch.randn(4, 128)
|
||||
w1 = torch.randn(4, 256, 128)
|
||||
w2 = torch.randn(4, 128, 128)
|
||||
topk_weights = torch.rand(4, 1)
|
||||
topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
|
||||
top_k = 1
|
||||
global_num_experts = 4
|
||||
|
||||
moe_parallel_config = type(
|
||||
'MockConfig', (), {
|
||||
'ep_size': 1,
|
||||
'tp_size': 1,
|
||||
'dp_size': 1,
|
||||
'tp_rank': 0,
|
||||
'dp_rank': 0,
|
||||
'ep_rank': 0,
|
||||
'use_ep': True
|
||||
})()
|
||||
|
||||
output = fused_experts_moge(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
self.assertEqual(output.shape, (4, 128))
|
||||
@@ -27,9 +27,9 @@ from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import (FusedMoEState,
|
||||
_get_fused_moe_state)
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod,
|
||||
unified_apply_mlp)
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch
|
||||
|
||||
adapt_patch(True)
|
||||
@@ -129,36 +129,38 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
with_quant=False)
|
||||
|
||||
with patch('torch.distributed.get_rank', return_value=0), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('torch.distributed.all_gather'), \
|
||||
patch('torch.distributed.all_to_all_single'), \
|
||||
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
|
||||
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
|
||||
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
||||
return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
|
||||
return_value=MagicMock(
|
||||
torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
|
||||
expert_map_path=None
|
||||
)), \
|
||||
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
|
||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('torch.distributed.all_gather'), \
|
||||
patch('torch.distributed.all_to_all_single'), \
|
||||
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
|
||||
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
|
||||
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
||||
return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
|
||||
return_value=MagicMock(
|
||||
torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
|
||||
expert_map_path=None
|
||||
)), \
|
||||
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
|
||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||
return_value=MagicMock(
|
||||
parallel_config=MagicMock(tensor_parallel_size=2),
|
||||
scheduler_config=MagicMock(max_num_seqs=4),
|
||||
model_config=MagicMock(max_model_len=2048)
|
||||
)), \
|
||||
return_value=MagicMock(
|
||||
parallel_config=MagicMock(tensor_parallel_size=2),
|
||||
scheduler_config=MagicMock(max_num_seqs=4),
|
||||
model_config=MagicMock(max_model_len=2048)
|
||||
)), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
|
||||
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers):
|
||||
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
|
||||
patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context',
|
||||
return_value=mock_forward_context_obj):
|
||||
|
||||
yield {
|
||||
'mock_forward_context_obj': mock_forward_context_obj,
|
||||
@@ -441,12 +443,11 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
|
||||
assert result.shape == expected_shape
|
||||
|
||||
@pytest.mark.parametrize("others_param",
|
||||
[[16, False], [1, True], [1, False], [4, False]])
|
||||
@pytest.mark.parametrize("others_param", [16, 1, 4])
|
||||
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
|
||||
ep_size, alltoall_buffer = others_param
|
||||
ep_size = others_param
|
||||
is_prefill = False
|
||||
|
||||
if ep_size == 1:
|
||||
@@ -464,9 +465,7 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
with_quant=False,
|
||||
token_dispatcher=selected_token_dispatcher)
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
|
||||
alltoall_buffer), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
|
||||
|
||||
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
|
||||
@@ -475,8 +474,6 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
if ep_size == 1:
|
||||
x = x.view(-1, 2)
|
||||
router_logits = torch.randn(8, 8)
|
||||
if alltoall_buffer:
|
||||
moe_method.max_model_len = 1
|
||||
layer = MagicMock()
|
||||
|
||||
local_num_experts = 2
|
||||
@@ -529,9 +526,8 @@ class TestExpertsSelector:
|
||||
|
||||
class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
|
||||
@patch('vllm_ascend.ops.fused_moe.get_mc2_group')
|
||||
@patch('vllm_ascend.ops.fused_moe.is_310p')
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@patch('torch_npu.npu_dequant_swiglu_quant')
|
||||
@@ -539,16 +535,12 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
mock_npu_dynamic_quant,
|
||||
mock_npu_grouped_matmul,
|
||||
mock_is_310p,
|
||||
mock_get_mc2_group,
|
||||
mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.fused_moe_state = FusedMoEState.MC2
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_mc2_group = MagicMock()
|
||||
mock_get_mc2_group.return_value = mock_mc2_group
|
||||
|
||||
mock_is_310p.return_value = False
|
||||
|
||||
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
|
||||
@@ -601,7 +593,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.is_310p')
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -643,7 +635,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.float16)
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -703,7 +695,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.is_310p')
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
|
||||
@@ -17,57 +17,13 @@
|
||||
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.ut.base import PytestBase, TestBase
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig,
|
||||
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
|
||||
TokenDispatcherWithMC2, _Dispatchers, _register_token_dispatcher,
|
||||
get_token_dispatcher, setup_token_dispatchers)
|
||||
|
||||
|
||||
class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase):
|
||||
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
config = MoEDispatcherConfig()
|
||||
config.set_num_local_experts(2)
|
||||
config.set_num_moe_experts(4)
|
||||
config.set_moe_pad_expert_input_to_capacity(False)
|
||||
config.set_moe_expert_capacity_factor(None)
|
||||
config.set_moe_router_topk(2)
|
||||
config.set_moe_grouped_gemm(False)
|
||||
config.set_group_topk(0)
|
||||
config.set_num_groups(1)
|
||||
config.set_is_fused(False)
|
||||
return config.build()
|
||||
|
||||
def mock_ep_group(self, mocker):
|
||||
mock_group = mocker.MagicMock()
|
||||
mock_group.rank_in_group = 0
|
||||
mock_group.world_size = 2
|
||||
mock_group.device_group = "mock_group"
|
||||
return mock_group
|
||||
|
||||
@pytest.fixture
|
||||
def dispatcher(self, config, mocker: MockerFixture):
|
||||
mocker.patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group",
|
||||
return_value=self.mock_ep_group(mocker))
|
||||
mocker.patch("torch.npu.current_device", return_value="cpu")
|
||||
mocker.patch("torch.npu.Stream", return_value=mocker.MagicMock)
|
||||
return MoEAlltoAllSeqOverLapDispatcher(config)
|
||||
|
||||
def test_initialization(self, dispatcher, config):
|
||||
assert dispatcher.num_local_experts == config.num_local_experts
|
||||
assert dispatcher.num_experts == config.num_moe_experts
|
||||
assert dispatcher.local_expert_indices == [0, 1]
|
||||
assert dispatcher.ep_rank == 0
|
||||
assert dispatcher.ep_size == 2
|
||||
assert dispatcher.overlap_stream is not None
|
||||
AscendSocVersion, TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
|
||||
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithMC2(TestBase):
|
||||
|
||||
@@ -353,8 +353,7 @@ class TestTorchairAscendUnquantizedFusedMoEMethod:
|
||||
else:
|
||||
assert result.shape == x.shape
|
||||
|
||||
@pytest.mark.parametrize("others_param",
|
||||
[[16, False], [1, True], [1, False], [4, False]])
|
||||
@pytest.mark.parametrize("others_param", [16, 1, 4])
|
||||
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
"""
|
||||
@@ -363,13 +362,11 @@ class TestTorchairAscendUnquantizedFusedMoEMethod:
|
||||
3 test use_select_experts and fused_experts_with_all2all
|
||||
4 test use_select_experts and fused_experts
|
||||
"""
|
||||
ep_size, alltoall_buffer = others_param
|
||||
ep_size = others_param
|
||||
is_prefill = False
|
||||
forward_context = MagicMock(
|
||||
fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
|
||||
with patch("vllm_ascend.torchair.ops.torchair_fused_moe.MOE_ALL2ALL_BUFFER",
|
||||
alltoall_buffer), \
|
||||
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
|
||||
with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
|
||||
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
|
||||
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
|
||||
moe_method.ep_size = ep_size
|
||||
@@ -377,8 +374,6 @@ class TestTorchairAscendUnquantizedFusedMoEMethod:
|
||||
if ep_size == 1:
|
||||
x = x.view(-1, 2)
|
||||
router_logits = torch.randn(8, 8)
|
||||
if alltoall_buffer:
|
||||
moe_method.max_model_len = 1
|
||||
layer = MagicMock()
|
||||
layer.w13_weight = torch.randn(8, 16, 1)
|
||||
layer.w2_weight = torch.randn(16, 8, 1)
|
||||
|
||||
@@ -35,10 +35,6 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||
return FusedMoEState.NaiveMulticast
|
||||
else:
|
||||
return FusedMoEState.AllGather
|
||||
elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
|
||||
# MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
|
||||
return (FusedMoEState.All2AllSeq if
|
||||
(ep_size < 16 or with_prefill) else FusedMoEState.MC2)
|
||||
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
|
||||
elif ep_size < 16 or with_prefill:
|
||||
return FusedMoEState.All2All
|
||||
|
||||
@@ -90,11 +90,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
|
||||
),
|
||||
# MOE_ALL2ALL_BUFFER:
|
||||
# 0: default, normal init.
|
||||
# 1: enable moe_all2all_buffer.
|
||||
"MOE_ALL2ALL_BUFFER":
|
||||
lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))),
|
||||
# Some models are optimized by vllm ascend. While in some case, e.g. rlhf
|
||||
# training, the optimized model may not be suitable. In this case, set this
|
||||
# value to False to disable the optimized model.
|
||||
@@ -136,11 +131,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# this feature is supported in A2, and eager mode will get better performance.
|
||||
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
|
||||
# Whether to enable the alltoall_seq flag, this provides a basic framework on the basis of alltoall for easy expansion.
|
||||
# 0: default, normal init.
|
||||
# 1: enable moe all2all seq.
|
||||
"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ":
|
||||
lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))),
|
||||
# Whether to enable mlp optimize when tensor parallel is enabled.
|
||||
# this feature in eager mode will get better performance.
|
||||
"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":
|
||||
|
||||
@@ -22,6 +22,8 @@ import torch_npu
|
||||
from vllm.config import CompilationLevel, get_current_vllm_config
|
||||
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import \
|
||||
FusedMoEParallelConfig # isort: skip
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod)
|
||||
|
||||
@@ -30,7 +32,6 @@ from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
||||
AlltoAllCommImpl,
|
||||
MC2CommImpl)
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe import fused_experts_moge
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
setup_token_dispatchers
|
||||
@@ -139,6 +140,95 @@ def fused_experts(
|
||||
return hidden_states
|
||||
|
||||
|
||||
def fused_experts_moge(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
||||
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
top_k: Number of experts to select.
|
||||
expert_map: Expert mapping of shape (num_experts,).
|
||||
|
||||
Returns:
|
||||
hidden_states: Hidden states after routing.
|
||||
"""
|
||||
ep_size = moe_parallel_config.ep_size
|
||||
local_num_experts = global_num_experts // ep_size
|
||||
local_num_group = top_k // ep_size
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
||||
|
||||
bsz, _ = hidden_states.shape
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
||||
sorted_hidden_states = hidden_states.index_select(
|
||||
0, sorted_topk_ids // local_num_group)
|
||||
|
||||
experts_id = torch.arange(0,
|
||||
local_num_experts,
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device)
|
||||
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
||||
torch.float32).sum(0)
|
||||
topk_scales = topk_weights.view(-1).index_select(
|
||||
0, sorted_topk_ids).unsqueeze(-1)
|
||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
||||
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[sorted_hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
if is_310p():
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
||||
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
|
||||
final_hidden_states = unsorted_hidden_states.reshape(
|
||||
bsz, top_k // ep_size, -1).sum(1)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
||||
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
|
||||
|
||||
|
||||
@@ -38,7 +38,6 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.communication_op import \
|
||||
@@ -46,397 +45,12 @@ from vllm_ascend.distributed.communication_op import \
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
|
||||
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor,
|
||||
get_all_reduce_merge_state,
|
||||
get_rm_router_logits_state, is_310p)
|
||||
|
||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||
|
||||
|
||||
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
||||
max_row_per_ep_rank: int, num_tokens: int,
|
||||
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
original_total_elements = num_tokens * top_k
|
||||
device = topk_ids.device
|
||||
original_dtype = topk_ids.dtype
|
||||
|
||||
if original_total_elements == 0:
|
||||
output_len = ep_size * max_row_per_ep_rank
|
||||
topk_ids_pad = torch.full((output_len, ),
|
||||
expert_num,
|
||||
dtype=original_dtype,
|
||||
device=device)
|
||||
unpad_indices = torch.full((original_total_elements, ),
|
||||
-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
return topk_ids_pad, unpad_indices
|
||||
|
||||
experts_per_ep_rank_val = expert_num // ep_size
|
||||
if experts_per_ep_rank_val == 0:
|
||||
raise ValueError(
|
||||
"expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
|
||||
"Ensure expert_num >= ep_size.")
|
||||
|
||||
assigned_ep_rank = (topk_ids.float() /
|
||||
experts_per_ep_rank_val).to(original_dtype)
|
||||
indices_arange = torch.arange(topk_ids.shape[0], device=device)
|
||||
|
||||
is_new_segment = torch.cat(
|
||||
(torch.tensor([True], device=device), assigned_ep_rank[1:]
|
||||
!= assigned_ep_rank[:-1]))
|
||||
temp_start_markers = torch.full_like(indices_arange,
|
||||
-1,
|
||||
dtype=indices_arange.dtype)
|
||||
temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
|
||||
start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
|
||||
token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
|
||||
is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
|
||||
cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
|
||||
indices_in_rec_cond_list_for_all = cumsum_kept - 1
|
||||
unpad_indices = torch.where(
|
||||
is_kept_mask, indices_in_rec_cond_list_for_all,
|
||||
torch.tensor(-1, device=device, dtype=torch.long))
|
||||
output_len = ep_size * max_row_per_ep_rank
|
||||
topk_ids_pad = torch.full((output_len, ),
|
||||
expert_num,
|
||||
dtype=original_dtype,
|
||||
device=device)
|
||||
if topk_ids.shape[0] > 0:
|
||||
all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
|
||||
temp_pad_buffer = torch.full((output_len + 1, ),
|
||||
expert_num,
|
||||
dtype=original_dtype,
|
||||
device=device)
|
||||
output_len_tensor = torch.tensor(output_len,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
scatter_indices = torch.where(is_kept_mask, all_destination_indices,
|
||||
output_len_tensor)
|
||||
temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
|
||||
topk_ids_pad = temp_pad_buffer[:output_len]
|
||||
return topk_ids_pad, unpad_indices
|
||||
|
||||
|
||||
def apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||
|
||||
Args:
|
||||
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
|
||||
w1: expert weights1 with shape
|
||||
(num_experts, hidden_size, intermediate_size * 2)
|
||||
w2: expert weights2 with shape
|
||||
(num_experts, intermediate_size, hidden_size)
|
||||
group_list: number of tokens for each expert, follow cumsum mode, and
|
||||
with shape (num_experts).
|
||||
transpose_weight:
|
||||
w1: (num_experts, intermediate_size * 2, hidden_size) ->
|
||||
(num_experts, hidden_size, intermediate_size * 2)
|
||||
w2: (num_experts, hidden_size, intermediate_size) ->
|
||||
(num_experts, intermediate_size, hidden_size)
|
||||
|
||||
Returns:
|
||||
hidden_states: output hidden states after MLP.
|
||||
"""
|
||||
|
||||
w1 = w1.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def fused_experts_moge(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    moe_parallel_config: FusedMoEParallelConfig,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    expert_map: torch.Tensor = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """
    Fused MoE experts path for MoGE-style models: sort tokens by expert,
    run grouped gate_up/swiglu/down, then unsort and reduce over top-k.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
        w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
        moe_parallel_config: Parallel config; only ep_size is read here.
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
        top_k: Number of experts to select.
        global_num_experts: Total experts across all EP ranks.
        expert_map: Expert mapping of shape (num_experts,).
            NOTE(review): accepted but not used in this function.
        apply_router_weight_on_input: If True, scale the input by the (single)
            routing weight before dispatch; requires top_k == 1.

    Returns:
        hidden_states: Hidden states after routing.
    """
    ep_size = moe_parallel_config.ep_size
    local_num_experts = global_num_experts // ep_size
    # Per-rank share of the top-k slots; used below to map sorted slot
    # indices back to token rows (assumes topk_ids has local_num_group
    # columns on this rank — TODO confirm against callers).
    local_num_group = top_k // ep_size

    if apply_router_weight_on_input:
        assert (topk_weights.dim() == 2
                ), "`topk_weights` should be in shape (num_tokens, topk)"
        _, topk = topk_weights.shape
        assert (
            topk == 1
        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)

    bsz, _ = hidden_states.shape
    # Stable grouping of (token, slot) pairs by expert id: argsort the
    # flattened expert ids, then gather the owning token row for each slot.
    flatten_topk_ids = topk_ids.view(-1)
    sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
    sorted_topk_ids = sorted_topk_ids.to(torch.int32)
    sorted_hidden_states = hidden_states.index_select(
        0, sorted_topk_ids // local_num_group)

    # Count tokens routed to each local expert via a one-hot comparison.
    experts_id = torch.arange(0,
                              local_num_experts,
                              dtype=topk_ids.dtype,
                              device=topk_ids.device)
    num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
        torch.float32).sum(0)
    # Routing weight for each sorted slot, kept as a column for broadcasting.
    topk_scales = topk_weights.view(-1).index_select(
        0, sorted_topk_ids).unsqueeze(-1)
    # Cumsum form expected by npu_grouped_matmul with group_list_type=0.
    group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)

    # gmm1: gate_up_proj
    gate_up_out = torch_npu.npu_grouped_matmul(
        x=[sorted_hidden_states],
        weight=[w1],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )[0]

    if is_310p():
        # 310P: run swiglu in fp32 and cast back to fp16.
        gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
            torch.float16)
    else:
        gate_up_out = torch_npu.npu_swiglu(gate_up_out)
    gate_up_out *= topk_scales

    # gmm2: down_proj
    down_out_list = torch_npu.npu_grouped_matmul(
        x=[gate_up_out],
        weight=[w2],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )[0]

    # Invert the sort permutation, restore (token, slot) order, and reduce
    # the per-slot expert outputs over this rank's top-k share.
    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
    unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
    final_hidden_states = unsorted_hidden_states.reshape(
        bsz, top_k // ep_size, -1).sum(1)

    return final_hidden_states
|
||||
|
||||
|
||||
def quant_apply_mlp(hidden_states: torch.Tensor,
                    w1: torch.Tensor,
                    w1_scale: torch.Tensor,
                    w2: torch.Tensor,
                    w2_scale: torch.Tensor,
                    group_list: torch.Tensor,
                    dynamic_scale: torch.Tensor = None,
                    group_list_type: int = 1,
                    w1_scale_bias: torch.Tensor = None,
                    w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
    """Quantized grouped MLP: gate_up_proj -> swiglu -> down_proj.

    Args:
        hidden_states: input tokens; quantized in-function when
            ``dynamic_scale`` is None, otherwise assumed already quantized.
        w1 / w1_scale: gate_up_proj weights and dequant scale.
        w2 / w2_scale: down_proj weights and dequant scale.
        group_list: per-expert token counts (mode given by group_list_type).
        dynamic_scale: per-token quantization scale of an already-quantized
            input; None means quantize here.
        group_list_type: group_list mode passed to npu_grouped_matmul
            (0 = cumsum, 1 = counts — presumably; TODO confirm with op docs).
        w1_scale_bias / w2_scale_bias: optional scale biases (w4a8 path).

    Returns:
        Dequantized output hidden states (dtype of w2_scale, or bfloat16 on
        the scale-bias path).
    """
    if dynamic_scale is None:
        unquantized_hidden_states = hidden_states
        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
            hidden_states)
        # Dispose the original unquantized hidden states
        # to save npu memory because they're no longer used.
        dispose_tensor(unquantized_hidden_states)
    else:
        pertoken_scale = dynamic_scale

    bias1, bias2 = None, None
    _output_dtype = w2_scale.dtype

    # MC2 dispatch without scale bias uses the fused dequant+swiglu+quant
    # kernel; every other combination goes through the generic branch below.
    is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
    if w1_scale_bias is None and is_mc2:
        # npu_dequant_swiglu_quant requires a float32 weight scale.
        w1_scale = w1_scale.to(torch.float32)

        # gmm1: gate_up_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w1],
            split_item=3,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=torch.int32)[0]

        # act_fn: swiglu (fused dequant -> swiglu -> requant)
        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
            x=hidden_states,
            weight_scale=w1_scale,
            activation_scale=pertoken_scale,
            bias=None,
            quant_scale=None,
            quant_offset=None,
            group_index=group_list,
            activate_left=True,
            quant_mode=1,
        )

        # gmm2: down_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w2],
            scale=[w2_scale],
            per_token_scale=[swiglu_out_scale],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=w2_scale.dtype)[0]
    else:
        if w1_scale_bias is not None:
            if group_list_type == 0:
                # Convert a cumsum group_list into per-group counts so the
                # rest of this branch can always use count mode.
                group_list = torch.cat(
                    [group_list[:1],
                     torch.diff(group_list, dim=0)])
                group_list_type = 1
            bias1 = [w1_scale_bias]
            bias2 = [w2_scale_bias]
            # TODO w4a8 scene: dynamic acquisition of dtype in the future
            _output_dtype = torch.bfloat16

        # gmm1: gate_up_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w1],
            scale=[w1_scale],
            bias=bias1,
            per_token_scale=[pertoken_scale],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=_output_dtype)[0]

        # act_fn: swiglu, then requantize for the second grouped matmul
        hidden_states = torch_npu.npu_swiglu(hidden_states)
        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
            hidden_states)

        # gmm2: down_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w2],
            scale=[w2_scale],
            bias=bias2,
            per_token_scale=[swiglu_out_scale],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=_output_dtype)[0]
    return hidden_states
|
||||
|
||||
|
||||
def unquant_apply_mlp(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        group_list: torch.Tensor,
        group_list_type: int = 1,
        topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Unquantized grouped MLP: gate_up_proj -> swiglu -> down_proj.

    Weights are stored with the last two dims swapped, so each grouped
    matmul receives a transposed view. When ``topk_scales`` is given, the
    activation is scaled per token before the down projection.
    """
    # gmm1: gate_up_proj
    up_out = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1.transpose(1, 2)],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]

    # act_fn: swiglu — on 310P compute in fp32 and cast back to fp16.
    if not is_310p():
        activated = torch_npu.npu_swiglu(up_out)
    else:
        activated = torch_npu.npu_swiglu(up_out.to(torch.float32)).to(
            torch.float16)

    if topk_scales is not None:
        activated *= topk_scales

    # gmm2: down_proj
    down_out = torch_npu.npu_grouped_matmul(
        x=[activated],
        weight=[w2.transpose(1, 2)],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]
    return down_out
|
||||
|
||||
|
||||
def unified_apply_mlp(hidden_states: torch.Tensor,
                      w1: torch.Tensor,
                      w1_scale: torch.Tensor,
                      w2: torch.Tensor,
                      w2_scale: torch.Tensor,
                      group_list: torch.Tensor,
                      dynamic_scale: torch.Tensor = None,
                      group_list_type: int = 1,
                      w1_scale_bias: torch.Tensor = None,
                      w2_scale_bias: torch.Tensor = None,
                      topk_scales: Optional[torch.Tensor] = None,
                      with_quant: bool = False) -> torch.Tensor:
    """Route to the quantized or unquantized grouped-MLP implementation.

    ``with_quant`` selects the path; the unquantized path ignores all
    scale/bias arguments, while the quantized path ignores ``topk_scales``.
    """
    if not with_quant:
        return unquant_apply_mlp(hidden_states=hidden_states,
                                 w1=w1,
                                 w2=w2,
                                 group_list=group_list,
                                 group_list_type=group_list_type,
                                 topk_scales=topk_scales)
    return quant_apply_mlp(hidden_states=hidden_states,
                           w1=w1,
                           w1_scale=w1_scale,
                           w2=w2,
                           w2_scale=w2_scale,
                           group_list=group_list,
                           dynamic_scale=dynamic_scale,
                           group_list_type=group_list_type,
                           w1_scale_bias=w1_scale_bias,
                           w2_scale_bias=w2_scale_bias)
|
||||
|
||||
|
||||
def unified_fused_experts_eager(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -742,24 +356,6 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
self.token_dispatcher = None
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance(
|
||||
self.quant_method, AscendUnquantizedFusedMoEMethod):
|
||||
self.reduce_results = False
|
||||
moe_dispatcher_config = (
|
||||
MoEDispatcherConfig().set_num_moe_experts(
|
||||
self.global_num_experts).set_num_local_experts(
|
||||
self.local_num_experts).set_moe_router_topk(
|
||||
top_k).set_group_topk(topk_group).
|
||||
set_num_groups(num_expert_group).set_expert_bias(
|
||||
e_score_correction_bias).set_scaling_factor(1.0).build())
|
||||
self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(
|
||||
moe_dispatcher_config)
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
||||
token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher(
|
||||
moe_dispatcher_config)
|
||||
self.token_dispatchers = [
|
||||
self.token_dispatcher, token_dispatcher1
|
||||
]
|
||||
|
||||
ep_size = (get_ep_group().world_size if
|
||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||
|
||||
199
vllm_ascend/ops/layers/moe_mlp.py
Normal file
199
vllm_ascend/ops/layers/moe_mlp.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.utils import dispose_tensor, is_310p
|
||||
|
||||
|
||||
def quant_apply_mlp(hidden_states: torch.Tensor,
                    w1: torch.Tensor,
                    w1_scale: torch.Tensor,
                    w2: torch.Tensor,
                    w2_scale: torch.Tensor,
                    group_list: torch.Tensor,
                    dynamic_scale: torch.Tensor = None,
                    group_list_type: int = 1,
                    w1_scale_bias: torch.Tensor = None,
                    w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
    """Quantized grouped MLP: gate_up_proj -> swiglu -> down_proj.

    Args:
        hidden_states: input tokens; quantized in-function when
            ``dynamic_scale`` is None, otherwise assumed already quantized.
        w1 / w1_scale: gate_up_proj weights and dequant scale.
        w2 / w2_scale: down_proj weights and dequant scale.
        group_list: per-expert token counts (mode given by group_list_type).
        dynamic_scale: per-token quantization scale of an already-quantized
            input; None means quantize here.
        group_list_type: group_list mode passed to npu_grouped_matmul
            (0 = cumsum, 1 = counts — presumably; TODO confirm with op docs).
        w1_scale_bias / w2_scale_bias: optional scale biases (w4a8 path).

    Returns:
        Dequantized output hidden states (dtype of w2_scale, or bfloat16 on
        the scale-bias path).
    """
    if dynamic_scale is None:
        unquantized_hidden_states = hidden_states
        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
            hidden_states)
        # Dispose the original unquantized hidden states
        # to save npu memory because they're no longer used.
        dispose_tensor(unquantized_hidden_states)
    else:
        pertoken_scale = dynamic_scale

    bias1, bias2 = None, None
    _output_dtype = w2_scale.dtype

    # MC2 dispatch without scale bias uses the fused dequant+swiglu+quant
    # kernel; every other combination goes through the generic branch below.
    is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
    if w1_scale_bias is None and is_mc2:
        # npu_dequant_swiglu_quant requires a float32 weight scale.
        w1_scale = w1_scale.to(torch.float32)

        # gmm1: gate_up_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w1],
            split_item=3,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=torch.int32)[0]

        # act_fn: swiglu (fused dequant -> swiglu -> requant)
        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
            x=hidden_states,
            weight_scale=w1_scale,
            activation_scale=pertoken_scale,
            bias=None,
            quant_scale=None,
            quant_offset=None,
            group_index=group_list,
            activate_left=True,
            quant_mode=1,
        )

        # gmm2: down_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w2],
            scale=[w2_scale],
            per_token_scale=[swiglu_out_scale],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=w2_scale.dtype)[0]
    else:
        if w1_scale_bias is not None:
            if group_list_type == 0:
                # Convert a cumsum group_list into per-group counts so the
                # rest of this branch can always use count mode.
                group_list = torch.cat(
                    [group_list[:1],
                     torch.diff(group_list, dim=0)])
                group_list_type = 1
            bias1 = [w1_scale_bias]
            bias2 = [w2_scale_bias]
            # TODO w4a8 scene: dynamic acquisition of dtype in the future
            _output_dtype = torch.bfloat16

        # gmm1: gate_up_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w1],
            scale=[w1_scale],
            bias=bias1,
            per_token_scale=[pertoken_scale],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=_output_dtype)[0]

        # act_fn: swiglu, then requantize for the second grouped matmul
        hidden_states = torch_npu.npu_swiglu(hidden_states)
        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
            hidden_states)

        # gmm2: down_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w2],
            scale=[w2_scale],
            bias=bias2,
            per_token_scale=[swiglu_out_scale],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=_output_dtype)[0]
    return hidden_states
|
||||
|
||||
|
||||
def unquant_apply_mlp(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        group_list: torch.Tensor,
        group_list_type: int = 1,
        topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Unquantized grouped MLP: gate_up_proj -> swiglu -> down_proj.

    Weights are stored with the last two dims swapped, so each grouped
    matmul receives a transposed view. When ``topk_scales`` is given, the
    activation is scaled per token before the down projection.
    """
    # gmm1: gate_up_proj
    up_out = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1.transpose(1, 2)],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]

    # act_fn: swiglu — on 310P compute in fp32 and cast back to fp16.
    if not is_310p():
        activated = torch_npu.npu_swiglu(up_out)
    else:
        activated = torch_npu.npu_swiglu(up_out.to(torch.float32)).to(
            torch.float16)

    if topk_scales is not None:
        activated *= topk_scales

    # gmm2: down_proj
    down_out = torch_npu.npu_grouped_matmul(
        x=[activated],
        weight=[w2.transpose(1, 2)],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]
    return down_out
|
||||
|
||||
|
||||
def unified_apply_mlp(hidden_states: torch.Tensor,
                      w1: torch.Tensor,
                      w1_scale: torch.Tensor,
                      w2: torch.Tensor,
                      w2_scale: torch.Tensor,
                      group_list: torch.Tensor,
                      dynamic_scale: torch.Tensor = None,
                      group_list_type: int = 1,
                      w1_scale_bias: torch.Tensor = None,
                      w2_scale_bias: torch.Tensor = None,
                      topk_scales: Optional[torch.Tensor] = None,
                      with_quant: bool = False) -> torch.Tensor:
    """Route to the quantized or unquantized grouped-MLP implementation.

    ``with_quant`` selects the path; the unquantized path ignores all
    scale/bias arguments, while the quantized path ignores ``topk_scales``.
    """
    if not with_quant:
        return unquant_apply_mlp(hidden_states=hidden_states,
                                 w1=w1,
                                 w2=w2,
                                 group_list=group_list,
                                 group_list_type=group_list_type,
                                 topk_scales=topk_scales)
    return quant_apply_mlp(hidden_states=hidden_states,
                           w1=w1,
                           w1_scale=w1_scale,
                           w2=w2,
                           w2_scale=w2_scale,
                           group_list=group_list,
                           dynamic_scale=dynamic_scale,
                           group_list_type=group_list_type,
                           w1_scale_bias=w1_scale_bias,
                           w2_scale_bias=w2_scale_bias)
|
||||
@@ -29,434 +29,11 @@ import torch_npu
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.distributed.tensor_parallel import (
|
||||
all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp,
|
||||
all_to_all_sp2hp, gather_from_sequence_parallel_region,
|
||||
reduce_scatter_last_dim_to_tensor_parallel_region)
|
||||
from vllm_ascend.distributed.tensor_parallel import \
|
||||
gather_from_sequence_parallel_region
|
||||
from vllm_ascend.ops.comm_utils import async_all_to_all
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
|
||||
class MoEDispatcherConfig:
    """Builder-style configuration for MoE token dispatchers.

    Every ``set_*`` method assigns one field and returns ``self`` so calls
    can be chained, finishing with :meth:`build`, which returns the
    configured instance itself.
    """

    def __init__(self):
        # Number of experts hosted on this EP rank.
        self.num_local_experts: int = 0
        # Total number of experts across all ranks.
        self.num_moe_experts: int = 0
        self.moe_pad_expert_input_to_capacity: bool = False
        self.moe_expert_capacity_factor: Optional[float] = None
        # Number of experts each token is routed to (top-k).
        self.moe_router_topk: int = 2
        self.moe_grouped_gemm: bool = False
        # Group-limited routing parameters.
        self.group_topk: int = 0
        self.num_groups: int = 1
        # Optional per-expert bias for router scores (None = disabled).
        self.expert_bias: Optional[torch.Tensor] = None
        self.scaling_factor: Optional[float] = None
        self.is_fused: bool = True

    def set_num_local_experts(self, num_local_experts):
        self.num_local_experts = num_local_experts
        return self

    def set_num_moe_experts(self, num_moe_experts):
        self.num_moe_experts = num_moe_experts
        return self

    def set_moe_pad_expert_input_to_capacity(self,
                                             moe_pad_expert_input_to_capacity):
        self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity
        return self

    def set_moe_expert_capacity_factor(self, moe_expert_capacity_factor):
        self.moe_expert_capacity_factor = moe_expert_capacity_factor
        return self

    def set_moe_router_topk(self, moe_router_topk):
        self.moe_router_topk = moe_router_topk
        return self

    def set_moe_grouped_gemm(self, moe_grouped_gemm):
        self.moe_grouped_gemm = moe_grouped_gemm
        return self

    def set_group_topk(self, group_topk):
        self.group_topk = group_topk
        return self

    def set_num_groups(self, num_groups):
        self.num_groups = num_groups
        return self

    def set_expert_bias(self, expert_bias):
        self.expert_bias = expert_bias
        return self

    def set_scaling_factor(self, scaling_factor):
        self.scaling_factor = scaling_factor
        return self

    def set_is_fused(self, is_fused):
        self.is_fused = is_fused
        return self

    def build(self):
        """Finish chained construction; returns this config instance."""
        return self
|
||||
|
||||
|
||||
class MoEDispatcher:
    """Base class for MoE token dispatchers.

    Holds the dispatcher config, an optional shared-experts module, and
    read-only accessors for the expert-parallel process group.
    """

    def __init__(self, config: MoEDispatcherConfig) -> None:
        """
        Initialize the MoE Token Dispatcher.
        """
        self.config = config
        # Optional shared-experts module, attached via set_shared_experts().
        self.shared_experts = None

    def set_shared_experts(self, shared_experts):
        """Attach the shared-experts module used alongside routed experts."""
        self.shared_experts = shared_experts

    @property
    def ep_group(self):
        """Get expert model parallel group."""
        return get_ep_group().device_group

    @property
    def ep_rank(self):
        # Rank of this process within the EP group.
        return get_ep_group().rank_in_group

    @property
    def ep_size(self):
        # World size of the EP group.
        return get_ep_group().world_size

    @property
    def tp_ep_group(self):
        """Get expert tensor and model parallel group (not used: None)."""
        return None

    @property
    def tp_ep_size(self):
        # Expert TP degree; fixed to 1 in this implementation.
        return 1
|
||||
|
||||
|
||||
class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher):
|
||||
overlap_stream = None
|
||||
"""
|
||||
The implementation of the AlltoAll-based token dispatcher, which handles token
|
||||
dispatching on the sequence level instead of token level. The core of this implementation
|
||||
lies in each device dispatching on the entire sequence, with the hidden state being partitioned.
|
||||
|
||||
"""
|
||||
|
||||
    def __init__(self, config: MoEDispatcherConfig):
        """
        Initialize the AlltoAllSeq token dispatcher.

        Args:
            config (MoEDispatcherConfig): Configuration for the transformer model.
        """
        super().__init__(config)
        self.num_local_experts = config.num_local_experts
        self.config = config
        # use MOEAlltoAllSEQTokenDispatcher to init

        self.hidden_shape = None
        self.num_input_tokens = None
        self.num_experts = config.num_moe_experts
        assert self.num_local_experts > 0, "Expected at least one expert"
        if self.num_local_experts > 1:
            # Maps each global expert id to its local slot on the owning rank;
            # used later to sort tokens by local expert after the AlltoAll.
            self.expert_ids_per_ep_rank = torch.tensor(
                [i % self.num_local_experts for i in range(self.num_experts)],
                dtype=torch.int32,
                device=torch.npu.current_device(),
            )

        # Global ids of the experts hosted on this rank (a contiguous range).
        local_expert_indices_offset = (self.ep_rank * self.num_local_experts)

        self.local_expert_indices = [
            local_expert_indices_offset + i
            for i in range(self.num_local_experts)
        ]
        assert (len(self.local_expert_indices) == self.num_local_experts
                ), "Invalid local expert indices"
        for i in range(len(self.local_expert_indices) - 1):
            assert (self.local_expert_indices[i] ==
                    self.local_expert_indices[i + 1] -
                    1), "local_expert_indices must be continuous"
        # Per-dispatch state, populated by preprocess()/token_permutation().
        self.probs = None
        self.input_splits = None
        self.output_splits = None
        self.routing_map = None
        self.hidden_shape_before_permute = None

        # [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent
        # to each local expert by all ranks.
        self.num_global_tokens_per_local_expert_cpu = None
        self.num_global_tokens_per_local_expert = None

        # A cuda stream synchronization is needed in self.token_permutation()
        # in some cases, because there are several non-blocking DtoH data
        # transfers called in self.preprocess(). The synchronization happens
        # at different points based on MoE settings as late as possible.
        # Valid sync points are "before_permutation_1", "before_ep_alltoall",
        # "before_finish", and "no_sync".
        self.device_sync_point = "no_sync"

        # cached intermediate tensors.
        self.cached_permutated_local_input_tokens = None
        self.cached_global_input_tokens = None
        self.cached_shared_expert_output = None
        self.tokens_per_expert = None
        self.perm1_finish_event = None
        self.global_input_tokens_local_experts_indices = None

        # One NPU stream shared by all dispatcher instances (class attribute),
        # created lazily on first construction.
        if MoEAlltoAllSeqOverLapDispatcher.overlap_stream is None:
            MoEAlltoAllSeqOverLapDispatcher.overlap_stream = torch.npu.Stream()

        self.overlap_stream = MoEAlltoAllSeqOverLapDispatcher.overlap_stream
|
||||
|
||||
    def preprocess(self,
                   indices: torch.Tensor,
                   with_sync=True) -> torch.Tensor:
        """
        Preprocess routing indices for AlltoAll communication and token permutation.

        This method computes the number of tokens assigned to each expert based on
        the routing indices. It also initializes the necessary data structures for
        AlltoAll communication, such as input and output splits, and the mapping
        between global tokens and local experts.

        Args:
            indices (torch.Tensor): expert indices of the routed tokens
                (one expert id per token-slot).
            with_sync (bool): when True (and more than one local expert),
                also build the permute-2 expert indices and clear the
                pending device sync point.

        Returns:
            torch.Tensor: Tensor containing the number of tokens assigned to local expert.
        """
        num_local_tokens_per_expert = torch.histc(indices,
                                                  bins=self.num_experts,
                                                  min=0,
                                                  max=self.num_experts)

        # num_local_tokens_per_expert: [num_experts]

        ep_size = self.ep_size

        # Dropless
        self.num_out_tokens = indices.numel()
        if self.ep_size > 1 or self.num_local_experts > 1:
            # Token dropless and enable ep. A synchronization is needed before expert parallel
            # AlltoAll communication to get the `input_splits` and `output_splits` CPU values.
            self.device_sync_point = "before_ep_alltoall"
        else:
            # Token dropless and no ep. A synchronization is needed to get the
            # `tokens_per_expert` CPU value.
            self.device_sync_point = "before_finish"

        if ep_size > 1:
            # ===================================================
            # Calculate input_splits, output_splits for alltoall-v.
            # ===================================================
            # Tokens this rank sends to each EP rank (non-blocking DtoH copy;
            # the sync point above guarantees the value is ready when read).
            self.input_splits = (num_local_tokens_per_expert.reshape(
                ep_size, self.num_local_experts).sum(axis=1).to(
                    torch.device("cpu"), non_blocking=True).numpy())
            num_global_tokens_per_expert = gather_from_sequence_parallel_region(
                num_local_tokens_per_expert,
                group=self.ep_group).reshape(ep_size, self.num_experts)
            # Keep only the columns for the experts hosted on this rank.
            self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
                0]:self.local_expert_indices[-1] + 1]
            if self.num_global_tokens_per_local_expert is None:
                raise ValueError(
                    "num_global_tokens_per_local_expert must be set before sum."
                )
            # Tokens this rank receives from each EP rank.
            self.output_splits = (self.num_global_tokens_per_local_expert.sum(
                axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
            num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
                axis=0)
            # ===================================================
            # num_global_tokens_per_expert: [ep_size, num_experts]
            # num_global_tokens_per_local_expert: [ep_size, num_local_experts]
            # num_tokens_per_local_expert: [num_local_experts]
            # ===================================================
        else:
            self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
                -1, self.num_experts)
            num_tokens_per_local_expert = num_local_tokens_per_expert

        if self.num_local_experts > 1 and with_sync:
            if self.num_global_tokens_per_local_expert is None:
                raise ValueError(
                    "num_global_tokens_per_local_expert must be set before operations."
                )
            self.device_sync_point = "no_sync"
            # Expert index of each received token, used by permute 2 to sort
            # the AlltoAll output by local expert.
            self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
                self.expert_ids_per_ep_rank,
                self.num_global_tokens_per_local_expert.ravel())

        return num_tokens_per_local_expert
|
||||
|
||||
    def token_permutation(
        self,
        hidden_states: torch.Tensor,
        probs: torch.Tensor,
        routing_map: torch.Tensor,
    ):
        """
        Dispatch tokens to local experts using AlltoAllSeq communication.

        Args:
            hidden_states (torch.Tensor): Input token embeddings.
            probs (torch.Tensor): Probs of tokens assigned to experts.
                Shape: [num_tokens, num_experts].
            routing_map (torch.Tensor): Mapping of tokens assigned to experts.
                Shape: [num_tokens, num_experts].

        Returns:
            Tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
                - Shared-experts output (None when no shared experts are set).
                - Permuted token embeddings for local experts.
                - Number of tokens per expert.
        """
        self.hidden_shape = hidden_states.shape
        self.probs = probs
        self.top_indices = routing_map
        assert probs.dim() == 2, "Expected 2D tensor for probs"
        assert routing_map.dim() == 2, "Expected 2D tensor for routing map"

        # Permutation 1: input to AlltoAll input
        def alltoall_token_permutation1(hidden_states, routing_map):
            assert self.hidden_shape is not None
            hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
            tokens_per_expert = self.preprocess(routing_map)
            if self.tp_ep_size > 1:
                hidden_states = all_to_all_sp2hp(hidden_states,
                                                 group=self.tp_ep_group)
            self.hidden_shape_before_permute = hidden_states.shape

            if self.device_sync_point == "before_permutation_1":
                torch.npu.current_stream().synchronize()

            # Sort tokens by destination expert for the AlltoAll send.
            permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
                tokens=hidden_states,
                indices=self.top_indices,
                num_out_tokens=self.num_out_tokens,
            )
            return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert

        permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = alltoall_token_permutation1(
            hidden_states, routing_map)
        # Saved for token_unpermutation() to restore the original order.
        self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping
        # permute 1

        ep_group = self.ep_group

        # Perform expert parallel AlltoAll communication
        if self.device_sync_point == "before_ep_alltoall":
            torch.npu.current_stream().synchronize()
        _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
            permutated_local_input_tokens,
            self.output_splits,
            self.input_splits,
            ep_group,
        )

        # shared experts compute — overlapped with the in-flight AlltoAll
        if self.shared_experts is not None:
            (share_experts_output), *_ = self.shared_experts(hidden_states)
        else:
            share_experts_output = None

        permute1_ep_all_to_all_handle.wait()
        # Free the send buffer eagerly; it is no longer needed.
        permutated_local_input_tokens.untyped_storage().resize_(0)

        def alltoall_token_permutation2(global_input_tokens):
            # Permutation 2: Sort tokens by local expert.
            if self.num_local_experts > 1:
                global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
                    global_input_tokens,
                    self.global_input_tokens_local_experts_indices)

            # Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens.
            # global_input_tokens: [SEQL, H/TP] -> [SEQL, H]
            if self.tp_ep_size > 1 and self.config.moe_grouped_gemm:
                global_input_tokens = all_gather_last_dim_from_tensor_parallel_region(
                    global_input_tokens, self.tp_ep_group)
            if self.device_sync_point == "before_finish":
                torch.npu.current_stream().synchronize()

            return global_input_tokens

        # token premute2 input
        global_input_tokens = alltoall_token_permutation2(global_input_tokens)

        return share_experts_output, global_input_tokens, tokens_per_expert
|
||||
|
||||
def token_unpermutation(self,
                        hidden_states: torch.Tensor,
                        bias: torch.Tensor = None):
    """
    Reverse the token permutation to restore the original order.

    Inverse of ``token_permutation``: undoes the local expert sort
    (permutation 2), sends tokens back to their origin EP ranks via
    all-to-all, then undoes the initial top-k expansion (permutation 1)
    while applying the routing probabilities.

    Args:
        hidden_states (torch.Tensor): Output from local experts.
        bias (torch.Tensor, optional): Bias tensor (not supported).

    Returns:
        Tuple[torch.Tensor, Optional[torch.Tensor]]:
            - Unpermuted token embeddings in the original order.
            - None (bias is not supported).
    """

    def alltoall_token_unpermutation1(hidden_states):
        assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher"
        # Perform tensor parallel Reduce-Scatter
        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
        if self.tp_ep_size > 1:
            hidden_states = reduce_scatter_last_dim_to_tensor_parallel_region(
                hidden_states, group=self.tp_ep_group)

        # Unpermutation 2: expert output to AlltoAll input
        # (inverse of the per-local-expert sort done during permutation;
        # skipped when there is nothing to unsort)
        if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
            hidden_states = torch_npu.npu_moe_token_unpermute(
                hidden_states,
                self.reversed_global_input_permutation_mapping)

        return hidden_states

    hidden_states = alltoall_token_unpermutation1(hidden_states)

    ep_group = self.ep_group
    # Perform expert parallel AlltoAll communication
    # hidden_states: [SEQL, H] -> [SEQL, H/TP]
    # NOTE: splits are swapped relative to the forward dispatch
    # (input_splits as send, output_splits as receive) to route tokens
    # back to their source ranks.
    _, permutated_local_input_tokens, handle = async_all_to_all(
        hidden_states, self.input_splits, self.output_splits, ep_group)
    handle.wait()
    # Free the send buffer eagerly once the collective has completed to
    # reduce peak device memory.
    hidden_states.untyped_storage().resize_(0)

    def alltoall_token_unpermutation2(permutated_local_input_tokens):
        # Unpermutation 1: AlltoAll output to output
        # Restores original token order and folds in routing probs
        # (weighted combine of the top-k expert outputs per token).
        output = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=permutated_local_input_tokens,
            sorted_indices=self.reversed_local_input_permutation_mapping.
            to(torch.int32),
            probs=self.probs,
            restore_shape=self.hidden_shape_before_permute)

        # Perform tensor parallel AlltoAll communication
        # output: [S*B, H/TP] -> [S*B/TP, H]
        if self.tp_ep_size > 1:
            output = all_to_all_hp2sp(output, self.tp_ep_group)

        # Reshape the output tensor back to the caller's original shape
        output = output.view(self.hidden_shape)
        return output

    output = alltoall_token_unpermutation2(permutated_local_input_tokens)

    # Drop per-batch routing state so stale splits cannot leak into the
    # next forward pass.
    self.input_splits = None
    self.output_splits = None
    self.num_global_tokens_per_local_expert = None
    self.num_global_tokens_per_local_expert_cpu = None

    return output, None
|
||||
|
||||
|
||||
_Dispatchers: Dict[str, Any] = {}
|
||||
|
||||
|
||||
@@ -1090,7 +667,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher"
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
||||
|
||||
hidden_states = self._combine_preprocess(hidden_states)
|
||||
|
||||
|
||||
@@ -38,15 +38,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
|
||||
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||
@@ -54,74 +51,6 @@ from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||
get_ascend_soc_version,
|
||||
get_rm_router_logits_state, is_310p)
|
||||
|
||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||
|
||||
|
||||
def torchair_process_topk_ids(topk_ids: torch.Tensor, expert_num: int,
                              ep_size: int, max_row_per_ep_rank: int,
                              num_tokens: int,
                              top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Pack sorted expert ids into fixed-size per-EP-rank buckets.

    ``topk_ids`` is assumed sorted by expert id, so all ids owned by one
    EP rank form a contiguous run. Each rank gets exactly
    ``max_row_per_ep_rank`` slots; tokens beyond that capacity are
    dropped. Empty slots carry the sentinel value ``expert_num``.

    Args:
        topk_ids: Sorted (token*top_k,) expert ids.
        expert_num: Total number of experts; also the padding sentinel.
        ep_size: Expert-parallel world size.
        max_row_per_ep_rank: Slot capacity per EP rank.
        num_tokens: Number of input tokens.
        top_k: Experts selected per token.

    Returns:
        topk_ids_pad: (ep_size * max_row_per_ep_rank,) bucketed ids,
            padded with ``expert_num``.
        unpad_indices: (num_tokens * top_k,) position of each original
            entry in the kept (compacted) sequence, or -1 if dropped.
    """
    dev = topk_ids.device
    ids_dtype = topk_ids.dtype
    total = num_tokens * top_k
    padded_len = ep_size * max_row_per_ep_rank

    if total == 0:
        # Nothing routed: every slot holds the sentinel, no unpad mapping.
        return (
            torch.full((padded_len, ), expert_num, dtype=ids_dtype,
                       device=dev),
            torch.full((total, ), -1, dtype=torch.long, device=dev),
        )

    per_rank_experts = expert_num // ep_size
    if per_rank_experts == 0:
        raise ValueError(
            "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
            "Ensure expert_num >= ep_size.")

    # Owning EP rank of every entry (float divide + truncation == floor
    # for non-negative expert ids).
    owner_rank = (topk_ids.float() / per_rank_experts).to(ids_dtype)
    positions = torch.arange(topk_ids.shape[0], device=dev)

    # Mark the head of every run of equal ranks, then propagate each
    # run-start index forward with a running max; the offset from the
    # run start is the entry's slot inside its rank's bucket.
    segment_head = torch.cat((torch.tensor([True], device=dev),
                              owner_rank[1:] != owner_rank[:-1]))
    head_markers = torch.full_like(positions, -1, dtype=positions.dtype)
    head_markers[segment_head] = positions[segment_head]
    run_start = torch.cummax(head_markers, dim=0)[0]
    slot_in_rank = positions - run_start

    # Entries past the bucket capacity are dropped; kept entries are
    # compacted, and unpad_indices records where each one landed.
    kept = slot_in_rank < max_row_per_ep_rank
    compact_pos = torch.cumsum(kept.float(), dim=0).to(torch.long) - 1
    unpad_indices = torch.where(
        kept, compact_pos, torch.tensor(-1, device=dev, dtype=torch.long))

    topk_ids_pad = torch.full((padded_len, ), expert_num, dtype=ids_dtype,
                              device=dev)
    if topk_ids.shape[0] > 0:
        dest = owner_rank * max_row_per_ep_rank + slot_in_rank
        # One spare trailing slot absorbs every dropped entry, then is
        # sliced away so collisions there are harmless.
        staging = torch.full((padded_len + 1, ), expert_num,
                             dtype=ids_dtype, device=dev)
        overflow_slot = torch.tensor(padded_len, dtype=torch.long,
                                     device=dev)
        staging.scatter_(0, torch.where(kept, dest, overflow_slot),
                         topk_ids)
        topk_ids_pad = staging[:padded_len]
    return topk_ids_pad, unpad_indices
|
||||
|
||||
|
||||
def torchair_fused_experts_with_mc2(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -459,130 +388,6 @@ def torchair_fused_experts_with_all2all(
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
# currently expert parallelism implemented with all2all
# is under-optimized.
def torchair_fused_experts_with_all2all_buffer(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    max_model_len: int,
    global_batch_size: int,
    expert_map: torch.Tensor = None,
    ep_group: GroupCoordinator = None,
):
    """Fused MoE experts using fixed-size all-to-all buffers across EP ranks.

    Pipeline: expand tokens by top-k routing, bucket them into
    fixed-capacity per-rank buffers (padding with a sentinel expert id),
    all-to-all to the owning ranks, run the grouped expert MLP locally,
    all-to-all results back, and finalize with the routing weights.
    Fixed-size buffers make the collective shapes static at the cost of
    padding traffic.

    Args:
        hidden_states: (num_tokens, H) or (B, S, H) input activations.
        w1 / w2: Stacked per-expert MLP weights.
        topk_weights: Routing probabilities for the selected experts.
        topk_ids: Selected expert ids per token.
        top_k: Number of experts per token.
        max_model_len / global_batch_size: Used to size the worst-case
            per-rank buffer capacity.
        expert_map: Global-to-local expert mapping; its length is the
            global expert count.
        ep_group: Expert-parallel process group coordinator.

    Returns:
        Tensor of the same shape as the input with expert outputs combined.
    """
    original_shape = hidden_states.shape
    # Flatten (B, S, H) -> (B*S, H); restored at the end.
    if len(original_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    num_tokens, _ = hidden_states.shape
    device = hidden_states.device

    global_num_experts = len(expert_map)
    local_num_experts = global_num_experts // ep_group.world_size
    row_idx_len = num_tokens * top_k
    row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
                            device=device).view(top_k,
                                                -1).permute(1, 0).contiguous())
    # Expand tokens top_k-fold and sort by expert id; expanded_row_idx is
    # needed later to fold the expert outputs back per source token.
    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
        hidden_states,
        row_idx=row_idx,
        expert_idx=topk_ids,
        active_num=num_tokens)

    # Worst-case rows one EP rank may receive (the *2 is headroom).
    max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
                           max_model_len // ep_group.world_size +
                           1) * top_k * 2
    # Bucket expert ids into fixed-size per-rank slots; sentinel value is
    # global_num_experts, unpad_indices maps expanded rows -> kept rows.
    expert_idx_buffer_scatter, unpad_indices = torchair_process_topk_ids(
        expanded_expert_idx, global_num_experts, ep_group.world_size,
        max_row_per_ep_rank, num_tokens, top_k)
    # Gather indices that place each kept hidden-state row into its
    # buffer slot (padding slots reuse row 0; they are masked out later).
    hidden_states_pad_idx = torch.zeros(
        expert_idx_buffer_scatter.shape,
        dtype=expert_idx_buffer_scatter.dtype,
        device=expert_idx_buffer_scatter.device)
    non_pad_len = torch.sum((expert_idx_buffer_scatter
                             != global_num_experts).to(torch.int32))
    hidden_states_pad_idx[expert_idx_buffer_scatter !=
                          global_num_experts] = torch.arange(
                              non_pad_len,
                              dtype=expert_idx_buffer_scatter.dtype,
                              device=hidden_states.device)

    hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
    expert_idx_buffer_gather = torch.empty_like(
        expert_idx_buffer_scatter,
        dtype=expert_idx_buffer_scatter.dtype,
        device=expert_idx_buffer_scatter.device)
    hidden_states_buffer_gather = torch.empty_like(
        hidden_states_buffer_scatter,
        dtype=hidden_states_buffer_scatter.dtype,
        device=hidden_states_buffer_scatter.device)
    # Equal-split all-to-all: buffers are identically sized on all ranks.
    dist.all_to_all_single(expert_idx_buffer_gather,
                           expert_idx_buffer_scatter,
                           group=ep_group.device_group)
    dist.all_to_all_single(hidden_states_buffer_gather,
                           hidden_states_buffer_scatter,
                           group=ep_group.device_group)
    # Drop padding rows and translate global expert ids to local ones.
    mask = expert_idx_buffer_gather != global_num_experts
    local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
        global_num_experts // ep_group.world_size)
    hidden_states = hidden_states_buffer_gather[mask]
    idx_type = local_expert_idx.dtype
    # Sort received rows by local expert id (float sort, then cast back)
    # so the grouped MLP sees contiguous per-expert segments.
    sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
    sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)

    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
        sorted_local_expert_idx, local_num_experts).to(torch.int64)
    hidden_states = hidden_states[sorted_idx]
    # group_list_type 0: expert_tokens holds cumulative counts.
    group_list_type = 0

    hidden_states = torchair_apply_mlp(hidden_states,
                                       w1,
                                       w2,
                                       expert_tokens,
                                       group_list_type=group_list_type)

    # Undo the expert sort, re-insert padding rows, and all-to-all the
    # results back to the ranks that own the source tokens.
    resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
    hidden_states = hidden_states[resorted_idx]
    hidden_states_scatter = torch.zeros(
        (mask.shape[0], hidden_states.shape[1]),
        dtype=hidden_states.dtype,
        device=hidden_states.device)
    hidden_states_scatter[mask] = hidden_states
    hidden_states_gatter = torch.empty_like(
        hidden_states_scatter,
        dtype=hidden_states_scatter.dtype,
        device=hidden_states_scatter.device)
    dist.all_to_all_single(hidden_states_gatter,
                           hidden_states_scatter,
                           group=ep_group.device_group)
    hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter !=
                                                global_num_experts]
    if hidden_states_gatter.shape[0] != row_idx_len:
        # Some rows were dropped by the capacity limit: scatter the kept
        # results back to their expanded positions, leaving zeros for
        # dropped rows.
        hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
                                    dtype=hidden_states.dtype,
                                    device=hidden_states.device)
        hidden_states[unpad_indices != -1] = hidden_states_gatter
    else:
        # TODO: Reorder device memory 2 times here, replace the current
        hidden_states = hidden_states_gatter
    # Weighted combine of the top_k expert outputs per source token.
    final_hidden_states = torch_npu.npu_moe_finalize_routing(
        hidden_states,
        skip1=None,
        skip2=None,
        bias=None,
        scales=topk_weights,
        expanded_src_to_dst_row=expanded_row_idx,
        export_for_source_row=topk_ids,
    )

    if len(original_shape) == 3:
        final_hidden_states = final_hidden_states.view(original_shape)
    return final_hidden_states
|
||||
|
||||
|
||||
def torchair_fused_experts_moge(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -674,25 +479,6 @@ def torchair_fused_experts_moge(
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def torchair_fused_experts_with_all2allv(
    token_dispatcher,
    probs,
    routing_map,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
):
    """Run fused MoE experts through an alltoallv token dispatcher.

    Enable moe alltoallv, it's a balanced policy for precision and
    efficiency: the dispatcher permutes tokens to their experts, the
    grouped MLP runs locally, and unpermutation restores token order.

    Args:
        token_dispatcher: Dispatcher providing token_permutation /
            token_unpermutation.
        probs: Routing probabilities for the selected experts.
        routing_map: Selected expert ids per token.
        hidden_states: Input activations.
        w1 / w2: Stacked per-expert MLP weights.

    Returns:
        Combined expert outputs in the original token order.
    """
    # Dispatch: permute tokens to their experts (shared-expert output,
    # if any, is produced by the dispatcher but not consumed here).
    _share_out, dispatched, per_expert_counts = (
        token_dispatcher.token_permutation(hidden_states, probs,
                                           routing_map))

    # Grouped expert MLP over the dispatched tokens.
    expert_out = torchair_apply_mlp(dispatched, w1, w2, per_expert_counts)

    # Combine: restore original token order; the bias return is unused.
    combined, _bias = token_dispatcher.token_unpermutation(expert_out)
    return combined
|
||||
|
||||
|
||||
def torchair_fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -1120,28 +906,6 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
elif MOE_ALL2ALL_BUFFER:
|
||||
return torchair_fused_experts_with_all2all_buffer(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
max_model_len=self.max_model_len,
|
||||
global_batch_size=self.global_batch_size,
|
||||
expert_map=expert_map,
|
||||
ep_group=get_ep_group())
|
||||
elif fused_moe_state == FusedMoEState.All2AllSeq:
|
||||
token_dispatcher = kwargs.get("token_dispatcher")
|
||||
return torchair_fused_experts_with_all2allv(
|
||||
token_dispatcher=token_dispatcher,
|
||||
probs=topk_weights,
|
||||
routing_map=topk_ids,
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
)
|
||||
else:
|
||||
return torchair_fused_experts_with_all2all(
|
||||
hidden_states=x,
|
||||
@@ -1315,25 +1079,6 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
# NOTE: self.tp_group is not expert_tp_group
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
self.token_dispatcher = None
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance(
|
||||
self.quant_method, TorchairAscendUnquantizedFusedMoEMethod):
|
||||
self.reduce_results = False
|
||||
moe_dispatcher_config = (
|
||||
MoEDispatcherConfig().set_num_moe_experts(
|
||||
self.global_num_experts).set_num_local_experts(
|
||||
self.local_num_experts).set_moe_router_topk(
|
||||
top_k).set_group_topk(topk_group).
|
||||
set_num_groups(num_expert_group).set_expert_bias(
|
||||
e_score_correction_bias).set_scaling_factor(1.0).build())
|
||||
self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(
|
||||
moe_dispatcher_config)
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
||||
token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher(
|
||||
moe_dispatcher_config)
|
||||
self.token_dispatchers = [
|
||||
self.token_dispatcher, token_dispatcher1
|
||||
]
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
@@ -1486,7 +1231,6 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
shared_experts=shared_experts if self.torchair_graph_enabled
|
||||
and self.enable_multistream_moe and not is_prefill else None,
|
||||
mc2_mask=mc2_mask,
|
||||
token_dispatcher=self.token_dispatcher,
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user