From 6e00aed4d540432ffe2317601d47039d4ef2332b Mon Sep 17 00:00:00 2001 From: weijinqian0 <1184188277@qq.com> Date: Sat, 2 Aug 2025 09:49:10 +0800 Subject: [PATCH] [main][Feature]Moe alltoallv communication optimization for unquantized RL training sence (#2088) It comes from 0.9.1dev [0.9.1][Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo (#1547) - vLLM version: v0.10.0 - vLLM main: https://github.com/vllm-project/vllm/commit/97608dc276c292d9217eb6d334d969c5e89913c6 --------- Signed-off-by: weijinqian_v1 Signed-off-by: whx-sjtu <2952154980@qq.com> Signed-off-by: curryliu <120010041@link.cuhk.edu.cn> Signed-off-by: wangli Signed-off-by: ChenTaoyu-SJTU Signed-off-by: taoxudonghaha Signed-off-by: shen-shanshan <467638484@qq.com> Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com> Signed-off-by: leo-pony Signed-off-by: wangxiyuan Signed-off-by: MengqingCao Co-authored-by: weijinqian_v1 Co-authored-by: whx <56632993+whx-sjtu@users.noreply.github.com> Co-authored-by: curryliu <99582471+Irving11-BKN@users.noreply.github.com> Co-authored-by: Li Wang Co-authored-by: TaoYu Chen Co-authored-by: taoxudonghaha Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: leo-pony Co-authored-by: wangxiyuan Co-authored-by: Mengqing Cao --- .github/workflows/vllm_ascend_test.yaml | 1 + requirements-dev.txt | 1 + .../test_offline_inference_distributed.py | 22 + .../test_distributed_tensor_parallel.py | 139 ++++++ tests/ut/ops/test_token_dispatcher.py | 65 +++ vllm_ascend/ascend_forward_context.py | 5 + vllm_ascend/distributed/tensor_parallel.py | 248 ++++++++++ vllm_ascend/envs.py | 5 + vllm_ascend/models/__init__.py | 1 - vllm_ascend/models/qwen3_moe.py | 122 ++++- vllm_ascend/ops/comm_utils.py | 62 +++ vllm_ascend/ops/fused_moe.py | 158 +++++- vllm_ascend/ops/moe_dispatcher/__init__.py | 0 .../ops/moe_dispatcher/token_dispatcher.py | 453 ++++++++++++++++++ 14 files changed, 1265 insertions(+), 17 deletions(-) create mode 100644 tests/ut/distributed/test_distributed_tensor_parallel.py create mode 100644 tests/ut/ops/test_token_dispatcher.py create mode 100644 vllm_ascend/distributed/tensor_parallel.py create mode 100644 vllm_ascend/ops/comm_utils.py create mode 100644 vllm_ascend/ops/moe_dispatcher/__init__.py create mode 100644 vllm_ascend/ops/moe_dispatcher/token_dispatcher.py diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 769fc18..32363ff 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -278,6 +278,7 @@ jobs: pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC pytest -sv tests/e2e/multicard/test_data_parallel.py pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \ diff --git a/requirements-dev.txt b/requirements-dev.txt index ed71dfe..9be7f39 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,3 +17,4 @@ ray>=2.47.1 protobuf>3.20.0 librosa soundfile +pytest_mock \ No newline at end of file diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 7ddd5c7..7d1325d 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -157,6 +157,28 @@ def test_models_distributed_topk() -> None: vllm_model.generate(example_prompts, sampling_params) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1"}) +def test_models_distributed_alltoallv() -> None: + example_prompts = [ + "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", + "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", + "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", + ] + dtype = "half" + sampling_params = SamplingParams(max_tokens=5, + temperature=0.0, + top_k=50, + top_p=0.9) + + with VllmRunner( + "deepseek-ai/DeepSeek-V2-Lite", + dtype=dtype, + tensor_parallel_size=2, + distributed_executor_backend="mp", + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + + def test_models_distributed_Qwen3_W8A8(): example_prompts = [ "Hello, my name is", diff --git a/tests/ut/distributed/test_distributed_tensor_parallel.py b/tests/ut/distributed/test_distributed_tensor_parallel.py new file mode 100644 index 0000000..48a88fa --- /dev/null +++ b/tests/ut/distributed/test_distributed_tensor_parallel.py @@ -0,0 +1,139 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +import importlib + +import pytest +import torch +from pytest_mock import MockerFixture + +from tests.ut.base import PytestBase +from vllm_ascend.distributed.tensor_parallel import ( + _gather_along_first_dim, _gather_along_last_dim, + _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim, + all_to_all_hp2sp, all_to_all_sp2hp) + + +class TestDistributedCommunication(PytestBase): + + @pytest.fixture(autouse=True) + def context(self, mocker: MockerFixture): + mocker.patch("torch.npu.current_device", return_value="cpu") + mocker.patch("torch.distributed.get_world_size", return_value=4) + + mocker.patch("torch.distributed.get_rank", return_value=0) + + @pytest.mark.parametrize("world_size, test_tensor, expected", + [(1, torch.randn(8, 16), (8, 16)), + (4, torch.randn(8, 16), (32, 16))]) + def test_gather_along_first_dim(self, test_tensor, expected, world_size, + mocker: MockerFixture): + """test _gather_along_first_dim""" + mocker.patch("torch.distributed.get_world_size", + return_value=world_size) + + result = _gather_along_first_dim(test_tensor, mocker.MagicMock()) + + assert result.shape == expected + + @pytest.mark.parametrize("test_tensor, output_split_sizes, expected", [ + (torch.randn(8, 16), [5, 10, 15, 2], (32, 16)), + ]) + def test_gather_along_first_dim_unequal_split(self, test_tensor, expected, + output_split_sizes, + mocker: MockerFixture): + """test _gather_along_first_dim""" + + result = _gather_along_first_dim(test_tensor, mocker.MagicMock(), + output_split_sizes) + + assert result.shape == expected + + @pytest.mark.parametrize("world_size, test_tensor, expected", + [(1, torch.randn(8, 16, 32), (8, 16, 32)), + (4, torch.randn(8, 16, 32), (8, 16, 32 * 4))]) + def test_gather_along_last_dim(self, test_tensor, expected, world_size, + mocker: MockerFixture): + """test _gather_along_last_dim""" + mocker.patch("torch.distributed.get_world_size", + return_value=world_size) + + result = _gather_along_last_dim(test_tensor, mocker.MagicMock()) + + assert result.shape == expected + + @pytest.mark.parametrize("input_shape,expected_shape", [ + ((32, 16), (8, 16)), + ((40, 10), (10, 10)), + ]) + def test_reduce_scatter_along_first_dim(self, input_shape, expected_shape, + mocker: MockerFixture): + input_tensor = torch.randn(*input_shape) + result = _reduce_scatter_along_first_dim(input_tensor, + mocker.MagicMock()) + assert result.shape == expected_shape + + @pytest.mark.parametrize("input_shape,expected_shape", [ + ((8, 16, 32), (8, 16, 8)), + ]) + def test_reduce_scatter_along_last_dim(self, input_shape, expected_shape, + mocker: MockerFixture): + input_tensor = torch.randn(*input_shape) + result = _reduce_scatter_along_last_dim(input_tensor, + mocker.MagicMock()) + assert result.shape == expected_shape + + @pytest.mark.parametrize("func,input_shape,expected_shape", [ + ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32), + (8, 16, 128)), + ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)), + ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32), + (8, 16, 8)), + ("gather_from_sequence_parallel_region", (8, 16), (32, 16)), + ]) + def test_wrapper_functions(self, func, input_shape, expected_shape, + mocker: MockerFixture): + """test wrapper funcs""" + mod = importlib.import_module( + 'vllm_ascend.distributed.tensor_parallel') + globals = mod.__dict__ + test_func = globals[func] + input_tensor = torch.randn(*input_shape) + result = test_func(input_tensor, mocker.MagicMock()) + assert result.shape == expected_shape + + @pytest.mark.parametrize( + "input_shape,output_shape", + [ + ((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP] + ]) + def test_all_to_all_sp2hp(self, input_shape, output_shape, + mocker: MockerFixture): + input_tensor = torch.randn(*input_shape) + result = all_to_all_sp2hp(input_tensor, mocker.MagicMock()) + assert result.shape == output_shape + + @pytest.mark.parametrize( + "input_shape,output_shape", + [ + ((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H] + ]) + def test_all_to_all_hp2sp(self, input_shape, output_shape, + mocker: MockerFixture): + input_tensor = torch.randn(*input_shape) + result = all_to_all_hp2sp(input_tensor, mocker.MagicMock()) + assert result.shape == output_shape diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py new file mode 100644 index 0000000..3a42b93 --- /dev/null +++ b/tests/ut/ops/test_token_dispatcher.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +import pytest +from pytest_mock import MockerFixture + +from tests.ut.base import PytestBase +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( + MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) +from vllm_ascend.utils import adapt_patch # noqa E402 + + +class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase): + + @pytest.fixture + def config(self): + config = MoEDispatcherConfig() + config.set_num_local_experts(2) + config.set_num_moe_experts(4) + config.set_moe_pad_expert_input_to_capacity(False) + config.set_moe_expert_capacity_factor(None) + config.set_moe_router_topk(2) + config.set_moe_grouped_gemm(False) + config.set_group_topk(0) + config.set_num_groups(1) + config.set_is_fused(False) + return config.build() + + def mock_ep_group(self, mocker): + mock_group = mocker.MagicMock() + mock_group.rank_in_group = 0 + mock_group.world_size = 2 + mock_group.device_group = "mock_group" + return mock_group + + @pytest.fixture + def dispatcher(self, config, mocker: MockerFixture): + mocker.patch( + "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group", + return_value=self.mock_ep_group(mocker)) + mocker.patch("torch.npu.current_device", return_value="cpu") + mocker.patch("torch.npu.Stream", return_value=mocker.MagicMock) + return MoEAlltoAllSeqOverLapDispatcher(config) + + def test_initialization(self, dispatcher, config): + assert dispatcher.num_local_experts == config.num_local_experts + assert dispatcher.num_experts == config.num_moe_experts + assert dispatcher.local_expert_indices == [0, 1] + assert dispatcher.ep_rank == 0 + assert dispatcher.ep_size == 2 + assert dispatcher.overlap_stream is not None diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 83e4ee8..2d08079 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -18,6 +18,7 @@ class FusedMoEState(Enum): MC2 = 2 AllGatherEP = 3 NaiveMulticast = 4 + All2AllSeq = 5 # TODO(zzzzwwjj): add soc_version to choose branch @@ -33,6 +34,10 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool, return FusedMoEState.NaiveMulticast else: return FusedMoEState.AllGather + elif envs.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: + # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage. + return (FusedMoEState.All2AllSeq if + (ep_size < 16 or with_prefill) else FusedMoEState.MC2) # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. elif ep_size < 16 or with_prefill: return FusedMoEState.All2All diff --git a/vllm_ascend/distributed/tensor_parallel.py b/vllm_ascend/distributed/tensor_parallel.py new file mode 100644 index 0000000..3fff0a7 --- /dev/null +++ b/vllm_ascend/distributed/tensor_parallel.py @@ -0,0 +1,248 @@ +# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapts from: Megatron/megatron/core/tensor_parallel/mappings.py. +# This file is a part of the vllm-ascend project. +import torch + + +def _gather_along_first_dim(input_, group, output_split_sizes=None): + """Gather tensors and concatenate along the first dimension. + + Args: + input_tensor (torch.Tensor): + A tensor to be gathered. + output_split_sizes (List[int], optional): + A list specifying the sizes of the output splits along the first dimension. + If None, equal splitting is assumed. Default: None. + + Returns: + torch.Tensor: Gathered tensor. + """ + world_size = torch.distributed.get_world_size(group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + if output_split_sizes is None: + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + torch.distributed.all_gather_into_tensor(output, + input_.contiguous(), + group=group) + else: + dim_size[0] = sum(output_split_sizes) + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + output_tensor_list = list( + torch.split(output, output_split_sizes, dim=0)) + torch.distributed.all_gather(output_tensor_list, input_, group=group) + + return output + + +def _gather_along_last_dim(input_, group): + """Gather tensors and concatenate along the last dimension.""" + + world_size = torch.distributed.get_world_size(group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + torch.distributed.all_gather_into_tensor(output, + input_.contiguous(), + group=group) + tensor_list = output.chunk(world_size, dim=0) + output = torch.cat(tensor_list, dim=-1).contiguous() + + return output + + +def _reduce_scatter_along_first_dim(input_, + group, + input_split_sizes=None, + use_global_buffer=False): + """Reduce-scatter the input tensor across model parallel group. + + Args: + input_ (torch.Tensor): The input tensor to be reduce-scattered. + input_split_sizes (List[int], optional): A list specifying the sizes of + the input splits along the first dimension for each rank. If None, + equal splitting is assumed. Default: None. + """ + world_size = torch.distributed.get_world_size(group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + if input_split_sizes is None: + dim_size = list(input_.size()) + assert ( + dim_size[0] % world_size == 0 + ), "First dimension of the tensor should be divisible by tensor parallel size" + + dim_size[0] = dim_size[0] // world_size + + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + torch.distributed.reduce_scatter_tensor(output, + input_.contiguous(), + group=group) + else: + rank = torch.distributed.get_rank(group) + input_tensor_list = list(torch.split(input_, input_split_sizes, dim=0)) + + output = torch.empty_like(input_tensor_list[rank]) + torch.distributed.reduce_scatter(output, + input_tensor_list, + group=group) + return output + + +def _reduce_scatter_along_last_dim(input_, group): + """Reduce-scatter tensors on the last dimension.""" + world_size = torch.distributed.get_world_size(group) + target_shape = list(input_.size()) + target_shape[-1] = target_shape[-1] // world_size + input_ = input_.reshape(-1, input_.shape[-1]) + split_tensors = torch.split(input_, + split_size_or_sections=input_.shape[-1] // + world_size, + dim=1) + concat_tensor = torch.cat(split_tensors, dim=0) + output = _reduce_scatter_along_first_dim(concat_tensor, + group).reshape(target_shape) + return output + + +def all_gather_last_dim_from_tensor_parallel_region(input_, group): + """Wrapper for autograd function: forward: AG, backward RS """ + return _gather_along_last_dim(input_, group) + + +def reduce_scatter_to_sequence_parallel_region(input_, + group, + input_split_sizes=None): + """Wrapper for autograd function: forward: RS, backward AG """ + return _reduce_scatter_along_first_dim(input_, group, input_split_sizes) + + +def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group): + """Wrapper for autograd function: forward: RS, backward AG: AG """ + return _reduce_scatter_along_last_dim(input_, group) + + +def gather_from_sequence_parallel_region( + input_, + group, + output_split_sizes=None, +): + """Wrapper for autograd function: forward: AG, backward: RS """ + return _gather_along_first_dim(input_, group, output_split_sizes) + + +def all_to_all(group, input, output_split_sizes=None, input_split_sizes=None): + world_size = torch.distributed.get_world_size(group=group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input + + input = input.contiguous() + if output_split_sizes is None: + # Equal split (all2all) + output = torch.empty_like(input) + else: + # Unequal split (all2all-v) + output = input.new_empty( + size=[sum(output_split_sizes)] + list(input.size()[1:]), + dtype=input.dtype, + device=torch.npu.current_device(), + ) + torch.distributed.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + ) + return output + + +def all_to_all_sp2hp(input_, group): + """ + Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape + [num_tokens/TP, H] to [num_tokens, H/TP]. + + Args: + input_ (torch.Tensor): + The input tensor which has been distributed along the sequence + dimension. + + Returns: + torch.Tensor: The output tensor with shape [num_tokens, H/TP]. + + """ + if group is None: + return input_ + world_size = torch.distributed.get_world_size(group=group) + tp_group = group + input_ = input_.reshape(-1, input_.shape[-1]) + split_tensors = torch.split(input_, + split_size_or_sections=input_.shape[-1] // + world_size, + dim=1) + concat_tensor = torch.cat(split_tensors, dim=0) + output = all_to_all(tp_group, concat_tensor) + return output + + +def all_to_all_hp2sp(input_, group): + """ + Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape + [num_tokens, H/TP] to [num_tokens/TP, H]. + + Args: + input_ (torch.Tensor): + The input tensor which has been distributed along the hidden + dimension. + + Returns: + torch.Tensor: The output tensor with shape [num_tokens/TP, H]. + """ + if group is None: + return input_ + world_size = torch.distributed.get_world_size(group=group) + input_ = input_.reshape(-1, input_.shape[-1]) + tp_group = group + input_exchanged = all_to_all(tp_group, input_) + input_reshaped = input_exchanged.reshape(-1, input_exchanged.shape[-1]) + split_tensors = torch.split( + input_reshaped, + split_size_or_sections=input_reshaped.shape[0] // world_size, + dim=0) + output = torch.cat(split_tensors, dim=-1) + return output diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index f1aa480..dee6f5a 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -154,6 +154,11 @@ env_variables: Dict[str, Callable[[], Any]] = { # this feature is supported in A2, and eager mode will get better performance. "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))), + # Whether to enable the alltoall_seq flag, this provides a basic framework on the basis of alltoall for easy expansion. + # 0: default, normal init. + # 1: enable moe all2all seq. + "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": + lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))), } # end-env-vars-definition diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 0b1b67a..f47e821 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -40,7 +40,6 @@ def register_model(): ModelRegistry.register_model( "DeepseekV3ForCausalLM", "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") - else: ModelRegistry.register_model( "DeepseekV2ForCausalLM", diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 8ff1b52..0c5ad39 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -15,8 +15,111 @@ # limitations under the License. # Adapted from vllm/model_executor/models/qwen3_moe.py # This file is a part of the vllm-ascend project. +from typing import Optional -from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM +from torch import nn +from transformers import PretrainedConfig +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention, + Qwen3MoeDecoderLayer, + Qwen3MoeForCausalLM, + Qwen3MoeMLP, Qwen3MoeModel) +from vllm.model_executor.models.utils import ( + extract_layer_index, make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock +from vllm_ascend.platform import VllmConfig + + +class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = Qwen3MoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # `mlp_only_layers` in the config. + layer_idx = extract_layer_index(prefix) + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = AscendSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + +@support_torch_compile +class CustomQwen3MoeModel(Qwen3MoeModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens") + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CustomQwen3MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): @@ -33,3 +136,20 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = CustomQwen3MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) diff --git a/vllm_ascend/ops/comm_utils.py b/vllm_ascend/ops/comm_utils.py new file mode 100644 index 0000000..e893049 --- /dev/null +++ b/vllm_ascend/ops/comm_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +import torch +import torch.distributed +import torch.distributed as dist +import torch_npu + +COMM_STREAM = None + + +def async_all_to_all(input_, + output_split_sizes, + input_split_sizes, + group, + event=None): + if output_split_sizes is None: + # Equal split (all2all) + a2a_out = torch.empty_like(input_) + else: + # Unequal split (all2all-v) + a2a_out = input_.new_empty( + size=[sum(output_split_sizes)] + list(input_.size()[1:]), + dtype=input_.dtype, + device=torch.npu.current_device(), + ) + + if event: + # multi stream wait event + global COMM_STREAM + if COMM_STREAM is None: + COMM_STREAM = torch_npu.npu.Stream( + device=torch.npu.current_device()) + with torch_npu.npu.stream(COMM_STREAM): + event.wait() + handle = dist.all_to_all_single( + a2a_out, + input_.contiguous(), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True) + else: + handle = dist.all_to_all_single(a2a_out, + input_.contiguous(), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True) + return input_, a2a_out, handle diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 80d2140..b2b1ab9 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -16,12 +16,14 @@ # Adapted from vllm/tests/kernels/test_moe.py import os -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.distributed as dist import torch_npu from torch import nn +from transformers import PretrainedConfig +from vllm.attention import AttentionMetadata from vllm.config import get_current_vllm_config from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -35,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.config import \ FusedMoEParallelConfig # isort: skip from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map) +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig @@ -45,6 +48,8 @@ from vllm_ascend.distributed.communication_op import \ data_parallel_reduce_scatter from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( + MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_all_reduce_merge_state, @@ -273,11 +278,13 @@ def fused_experts_with_mc2( return hidden_states, shared_hidden_states -def apply_mlp(hidden_states_wrapper: List[torch.Tensor], - w1: torch.Tensor, - w2: torch.Tensor, - group_list: torch.Tensor, - group_list_type: int = 1) -> torch.Tensor: +def apply_mlp( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1, +) -> torch.Tensor: """ apply MLP: gate_up_proj -> swiglu -> down_proj @@ -299,9 +306,6 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor], hidden_states: output hidden states after MLP. """ - assert len(hidden_states_wrapper) == 1 - hidden_states = hidden_states_wrapper.pop() - w1 = w1.transpose(1, 2) hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], @@ -329,6 +333,8 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor], return hidden_states +# currently expert parallelism implemented with all2all +# is under-optimized. def fused_experts_with_all2all( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -543,10 +549,7 @@ def fused_experts_with_all2all_buffer( hidden_states = hidden_states[sorted_idx] group_list_type = 0 - hidden_states_wrapper = [hidden_states] - del hidden_states - - hidden_states = apply_mlp(hidden_states_wrapper, + hidden_states = apply_mlp(hidden_states, w1, w2, expert_tokens, @@ -682,6 +685,24 @@ def fused_experts_moge( return final_hidden_states +def fused_experts_with_all2allv( + token_dispatcher, + probs, + routing_map, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, +): + # Enable moe alltoallv, it's a balanced policy for precision and efficiency. + (share_experts_output, dispatched_input, + tokens_per_expert) = (token_dispatcher.token_permutation( + hidden_states, probs, routing_map)) + + expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert) + output, mlp_bias = token_dispatcher.token_unpermutation(expert_output) + return output + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -1124,6 +1145,16 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): global_batch_size=self.global_batch_size, expert_map=expert_map, ep_group=get_ep_group()) + elif fused_moe_state == FusedMoEState.All2AllSeq: + token_dispatcher = kwargs.get("token_dispatcher") + return fused_experts_with_all2allv( + token_dispatcher=token_dispatcher, + probs=topk_weights, + routing_map=topk_ids, + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + ) else: return fused_experts_with_all2all(hidden_states=x, w1=layer.w13_weight, @@ -1275,6 +1306,25 @@ class AscendFusedMoE(FusedMoE): # NOTE: self.tp_group is not expert_tp_group self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) + self.token_dispatcher = None + if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance( + self.quant_method, AscendUnquantizedFusedMoEMethod): + self.reduce_results = False + moe_dispatcher_config = ( + MoEDispatcherConfig().set_num_moe_experts( + self.global_num_experts).set_num_local_experts( + self.local_num_experts).set_moe_router_topk( + top_k).set_group_topk(topk_group). + set_num_groups(num_expert_group).set_expert_bias( + e_score_correction_bias).set_scaling_factor(1.0).build()) + self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher( + moe_dispatcher_config) + if envs_ascend.VLLM_ASCEND_ENABLE_DBO: + token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher( + moe_dispatcher_config) + self.token_dispatchers = [ + self.token_dispatcher, token_dispatcher1 + ] def naive_multicast(self, x: torch.Tensor, cu_tokens_across_dp_cpu: torch.Tensor): @@ -1414,6 +1464,7 @@ class AscendFusedMoE(FusedMoE): shared_experts=shared_experts if self.torchair_graph_enabled and self.enable_multistream_moe and not is_prefill else None, mc2_mask=mc2_mask, + token_dispatcher=self.token_dispatcher, quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, ) @@ -1430,11 +1481,11 @@ class AscendFusedMoE(FusedMoE): dist.all_gather(list(chunk_hidden_states), e_hidden_states, self.tp_group) final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + dispose_tensor(e_hidden_states) else: final_hidden_states = e_hidden_states if num_tokens < padding_size: final_hidden_states = final_hidden_states[:num_tokens] - dispose_tensor(e_hidden_states) elif self.dp_size > 1: if fused_moe_state == FusedMoEState.NaiveMulticast: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ @@ -1491,6 +1542,83 @@ class AscendFusedMoE(FusedMoE): scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, is_prefill=is_prefill, - enable_force_load_balance=enable_force_load_balance) + enable_force_load_balance=enable_force_load_balance, + ) + + return hidden_states + + +class AscendSparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_multistream_moe = ( + ascend_config.torchair_graph_config.enable_multistream_moe) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.experts = AscendFusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + self.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + self.ep_group = get_ep_group() + + self.params_dtype = torch.get_default_dtype() + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + enable_force_load_balance = get_forward_context().in_profile_run + is_prefill = get_forward_context().with_prefill + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=self.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=None, + ) return hidden_states diff --git a/vllm_ascend/ops/moe_dispatcher/__init__.py b/vllm_ascend/ops/moe_dispatcher/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py new file mode 100644 index 0000000..402e8fb --- /dev/null +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -0,0 +1,453 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +import torch_npu +from vllm.distributed.parallel_state import get_ep_group + +from vllm_ascend.distributed.tensor_parallel import ( + all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp, + all_to_all_sp2hp, gather_from_sequence_parallel_region, + reduce_scatter_last_dim_to_tensor_parallel_region) +from vllm_ascend.ops.comm_utils import async_all_to_all + + +class MoEDispatcherConfig: + + def __init__(self): + self.num_local_experts: int = 0 + self.num_moe_experts: int = 0 + self.moe_pad_expert_input_to_capacity: bool = False + self.moe_expert_capacity_factor: Optional[float] = None + self.moe_router_topk: int = 2 + self.moe_grouped_gemm: bool = False + self.group_topk: int = 0 + self.num_groups: int = 1 + self.expert_bias: torch.Tensor = None + self.scaling_factor: Optional[float] = None + self.is_fused: bool = True + + def set_num_local_experts(self, num_local_experts): + self.num_local_experts = num_local_experts + return self + + def set_num_moe_experts(self, num_moe_experts): + self.num_moe_experts = num_moe_experts + return self + + def set_moe_pad_expert_input_to_capacity(self, + moe_pad_expert_input_to_capacity): + self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity + return self + + def set_moe_expert_capacity_factor(self, moe_expert_capacity_factor): + self.moe_expert_capacity_factor = moe_expert_capacity_factor + return self + + def set_moe_router_topk(self, moe_router_topk): + self.moe_router_topk = moe_router_topk + return self + + def set_moe_grouped_gemm(self, moe_grouped_gemm): + self.moe_grouped_gemm = moe_grouped_gemm + return self + + def set_group_topk(self, group_topk): + self.group_topk = group_topk + return self + + def set_num_groups(self, num_groups): + self.num_groups = num_groups + return self + + def set_expert_bias(self, expert_bias): + self.expert_bias = expert_bias + return self + + def set_scaling_factor(self, scaling_factor): + self.scaling_factor = scaling_factor + return self + + def set_is_fused(self, is_fused): + self.is_fused = is_fused + return self + + def build(self): + return self + + +class MoEDispatcher: + + def __init__(self, config: MoEDispatcherConfig) -> None: + """ + Initialize the MoE Token Dispatcher. + """ + self.config = config + self.shared_experts = None + + def set_shared_experts(self, shared_experts): + self.shared_experts = shared_experts + + @property + def ep_group(self): + """Get expert model parallel group.""" + return get_ep_group().device_group + + @property + def ep_rank(self): + return get_ep_group().rank_in_group + + @property + def ep_size(self): + return get_ep_group().world_size + + @property + def tp_ep_group(self): + """Get expert tensor and model parallel group.""" + return None + + @property + def tp_ep_size(self): + return 1 + + +class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher): + overlap_stream = None + """ + The implementation of the AlltoAll-based token dispatcher, which handles token + dispatching on the sequence level instead of token level. The core of this implementation + lies in each device dispatching on the entire sequence, with the hidden state being partitioned. + + """ + + def __init__(self, config: MoEDispatcherConfig): + """ + Initialize the AlltoAllSeq token dispatcher. + + Args: + config (MoEDispatcherConfig): Configuration for the transformer model. + """ + super().__init__(config) + self.num_local_experts = config.num_local_experts + self.config = config + # use MOEAlltoAllSEQTokenDispatcher to init + + self.hidden_shape = None + self.num_input_tokens = None + self.num_experts = config.num_moe_experts + assert self.num_local_experts > 0, "Expected at least one expert" + if self.num_local_experts > 1: + self.expert_ids_per_ep_rank = torch.tensor( + [i % self.num_local_experts for i in range(self.num_experts)], + dtype=torch.int32, + device=torch.npu.current_device(), + ) + + local_expert_indices_offset = (self.ep_rank * self.num_local_experts) + + self.local_expert_indices = [ + local_expert_indices_offset + i + for i in range(self.num_local_experts) + ] + assert (len(self.local_expert_indices) == self.num_local_experts + ), "Invalid local expert indices" + for i in range(len(self.local_expert_indices) - 1): + assert (self.local_expert_indices[i] == + self.local_expert_indices[i + 1] - + 1), "local_expert_indices must be continuous" + self.probs = None + self.input_splits = None + self.output_splits = None + self.routing_map = None + self.hidden_shape_before_permute = None + + # [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent + # to each local expert by all ranks. + self.num_global_tokens_per_local_expert_cpu = None + self.num_global_tokens_per_local_expert = None + + # A cuda stream synchronization is needed in self.token_permutation() + # in some cases, because there are several non-blocking DtoH data + # transfers called in self.preprocess(). The synchronization happens + # at different points based on MoE settings as late as possible. + # Valid sync points are "before_permutation_1", "before_ep_alltoall", + # "before_finish", and "no_sync". + self.device_sync_point = "no_sync" + + # cached intermediate tensors. + self.cached_permutated_local_input_tokens = None + self.cached_global_input_tokens = None + self.cached_shared_expert_output = None + self.tokens_per_expert = None + self.perm1_finish_event = None + self.global_input_tokens_local_experts_indices = None + + if MoEAlltoAllSeqOverLapDispatcher.overlap_stream is None: + MoEAlltoAllSeqOverLapDispatcher.overlap_stream = torch.npu.Stream() + + self.overlap_stream = MoEAlltoAllSeqOverLapDispatcher.overlap_stream + + def preprocess(self, + indices: torch.Tensor, + with_sync=True) -> torch.Tensor: + """ + Preprocess routing map for AlltoAll communication and token permutation. + This method computes the number of tokens assigned to each expert based on + the routing map. It also initializes the necessary data structures for + AlltoAll communication, such as input and output splits, and the mapping + between global tokens and local experts. + + Args: + routing_map (torch.Tensor): The mapping of tokens to experts, with shape + [num_tokens, num_experts]. + + Returns: + torch.Tensor: Tensor containing the number of tokens assigned to local expert. + """ + num_local_tokens_per_expert = torch.histc(indices, + bins=self.num_experts, + min=0, + max=self.num_experts) + + # num_local_tokens_per_expert: [num_experts] + + ep_size = self.ep_size + + # Dropless + self.num_out_tokens = indices.numel() + if self.ep_size > 1 or self.num_local_experts > 1: + # Token dropless and enable ep. A synchronization is needed before expert parallel + # AlltoAll communication to get the `input_splits` and `output_splits` CPU values. + self.device_sync_point = "before_ep_alltoall" + else: + # Token dropless and no ep. A synchronization is needed to get the + # `tokens_per_expert` CPU value. + self.device_sync_point = "before_finish" + + if ep_size > 1: + # =================================================== + # Calculate input_splits, output_splits for alltoall-v. + # =================================================== + self.input_splits = (num_local_tokens_per_expert.reshape( + ep_size, self.num_local_experts).sum(axis=1).to( + torch.device("cpu"), non_blocking=True).numpy()) + num_global_tokens_per_expert = gather_from_sequence_parallel_region( + num_local_tokens_per_expert, + group=self.ep_group).reshape(ep_size, self.num_experts) + self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ + 0]:self.local_expert_indices[-1] + 1] + if self.num_global_tokens_per_local_expert is None: + raise ValueError( + "num_global_tokens_per_local_expert must be set before sum." + ) + self.output_splits = (self.num_global_tokens_per_local_expert.sum( + axis=-1).to(torch.device("cpu"), non_blocking=True).numpy()) + num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum( + axis=0) + # =================================================== + # num_global_tokens_per_expert: [ep_size, num_experts] + # num_global_tokens_per_local_expert: [ep_size, num_local_experts] + # num_tokens_per_local_expert: [num_local_experts] + # =================================================== + else: + self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape( + -1, self.num_experts) + num_tokens_per_local_expert = num_local_tokens_per_expert + + if self.num_local_experts > 1 and with_sync: + if self.num_global_tokens_per_local_expert is None: + raise ValueError( + "num_global_tokens_per_local_expert must be set before operations." + ) + self.device_sync_point = "no_sync" + self.global_input_tokens_local_experts_indices = torch.repeat_interleave( + self.expert_ids_per_ep_rank, + self.num_global_tokens_per_local_expert.ravel()) + + return num_tokens_per_local_expert + + def token_permutation( + self, + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + ): + """ + Dispatch tokens to local experts using AlltoAllSeq communication. + + Args: + hidden_states (torch.Tensor): Input token embeddings. + probs (torch.Tensor): Probs of tokens assigned to experts. + Shape: [num_tokens, num_experts]. + routing_map (torch.Tensor): Mapping of tokens assigned to experts. + Shape: [num_tokens, num_experts]. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - Permuted token embeddings for local experts. + - Number of tokens per expert. + """ + self.hidden_shape = hidden_states.shape + self.probs = probs + self.top_indices = routing_map + assert probs.dim() == 2, "Expected 2D tensor for probs" + assert routing_map.dim() == 2, "Expected 2D tensor for routing map" + + # Permutation 1: input to AlltoAll input + def alltoall_token_permutation1(hidden_states, routing_map): + assert self.hidden_shape is not None + hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) + tokens_per_expert = self.preprocess(routing_map) + if self.tp_ep_size > 1: + hidden_states = all_to_all_sp2hp(hidden_states, + group=self.tp_ep_group) + self.hidden_shape_before_permute = hidden_states.shape + + if self.device_sync_point == "before_permutation_1": + torch.npu.current_stream().synchronize() + + permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( + tokens=hidden_states, + indices=self.top_indices, + num_out_tokens=self.num_out_tokens, + ) + return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert + + permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = alltoall_token_permutation1( + hidden_states, routing_map) + self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping + # permute 1 + + ep_group = self.ep_group + + # Perform expert parallel AlltoAll communication + if self.device_sync_point == "before_ep_alltoall": + torch.npu.current_stream().synchronize() + _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( + permutated_local_input_tokens, + self.output_splits, + self.input_splits, + ep_group, + ) + + # shared experts compute + if self.shared_experts is not None: + (share_experts_output), *_ = self.shared_experts(hidden_states) + else: + share_experts_output = None + + permute1_ep_all_to_all_handle.wait() + permutated_local_input_tokens.untyped_storage().resize_(0) + + def alltoall_token_permutation2(global_input_tokens): + # Permutation 2: Sort tokens by local expert. + if self.num_local_experts > 1: + global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( + global_input_tokens, + self.global_input_tokens_local_experts_indices) + + # Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens. + # global_input_tokens: [SEQL, H/TP] -> [SEQL, H] + if self.tp_ep_size > 1 and self.config.moe_grouped_gemm: + global_input_tokens = all_gather_last_dim_from_tensor_parallel_region( + global_input_tokens, self.tp_ep_group) + if self.device_sync_point == "before_finish": + torch.npu.current_stream().synchronize() + + return global_input_tokens + + # token premute2 input + global_input_tokens = alltoall_token_permutation2(global_input_tokens) + + return share_experts_output, global_input_tokens, tokens_per_expert + + def token_unpermutation(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + """ + Reverse the token permutation to restore the original order. + + Args: + hidden_states (torch.Tensor): Output from local experts. + bias (torch.Tensor, optional): Bias tensor (not supported). + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Unpermuted token embeddings in the original order. + - None (bias is not supported). + """ + + def alltoall_token_unpermutation1(hidden_states): + assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher" + # Perform tensor parallel Reduce-Scatter + # hidden_states: [SEQL, H] -> [SEQL, H/TP] + if self.tp_ep_size > 1: + hidden_states = reduce_scatter_last_dim_to_tensor_parallel_region( + hidden_states, group=self.tp_ep_group) + + # Unpermutation 2: expert output to AlltoAll input + if hidden_states.shape[0] > 0 and self.num_local_experts > 1: + hidden_states = torch_npu.npu_moe_token_unpermute( + hidden_states, + self.reversed_global_input_permutation_mapping) + + return hidden_states + + hidden_states = alltoall_token_unpermutation1(hidden_states) + + ep_group = self.ep_group + # Perform expert parallel AlltoAll communication + # hidden_states: [SEQL, H] -> [SEQL, H/TP] + _, permutated_local_input_tokens, handle = async_all_to_all( + hidden_states, self.input_splits, self.output_splits, ep_group) + handle.wait() + hidden_states.untyped_storage().resize_(0) + + def alltoall_token_unpermutation2(permutated_local_input_tokens): + # Unpermutation 1: AlltoAll output to output + + output = torch_npu.npu_moe_token_unpermute( + permuted_tokens=permutated_local_input_tokens, + sorted_indices=self.reversed_local_input_permutation_mapping. + to(torch.int32), + probs=self.probs, + restore_shape=self.hidden_shape_before_permute) + + # Perform tensor parallel AlltoAll communication + # output: [S*B, H/TP] -> [S*B/TP, H] + if self.tp_ep_size > 1: + output = all_to_all_hp2sp(output, self.tp_ep_group) + + # Reshape the output tensor + output = output.view(self.hidden_shape) + return output + + output = alltoall_token_unpermutation2(permutated_local_input_tokens) + + self.input_splits = None + self.output_splits = None + self.num_global_tokens_per_local_expert = None + self.num_global_tokens_per_local_expert_cpu = None + + return output, None