[main][Feature]Moe alltoallv communication optimization for unquantized RL training sence (#2088)
It comes from 0.9.1dev
[0.9.1][Feature]Moe alltoallv communication optimization for unquantized
RL training sence & alltoallv support dpo (#1547)
- vLLM version: v0.10.0
- vLLM main:
97608dc276
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: ChenTaoyu-SJTU <ctynb@qq.com>
Signed-off-by: taoxudonghaha <justsheldon@163.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: whx <56632993+whx-sjtu@users.noreply.github.com>
Co-authored-by: curryliu <99582471+Irving11-BKN@users.noreply.github.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
Co-authored-by: TaoYu Chen <ctynb@qq.com>
Co-authored-by: taoxudonghaha <justsheldon@163.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
1
.github/workflows/vllm_ascend_test.yaml
vendored
1
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -278,6 +278,7 @@ jobs:
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
|
||||
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
||||
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
|
||||
|
||||
@@ -17,3 +17,4 @@ ray>=2.47.1
|
||||
protobuf>3.20.0
|
||||
librosa
|
||||
soundfile
|
||||
pytest_mock
|
||||
@@ -157,6 +157,28 @@ def test_models_distributed_topk() -> None:
|
||||
vllm_model.generate(example_prompts, sampling_params)
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1"})
|
||||
def test_models_distributed_alltoallv() -> None:
|
||||
example_prompts = [
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
|
||||
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
|
||||
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
|
||||
]
|
||||
dtype = "half"
|
||||
sampling_params = SamplingParams(max_tokens=5,
|
||||
temperature=0.0,
|
||||
top_k=50,
|
||||
top_p=0.9)
|
||||
|
||||
with VllmRunner(
|
||||
"deepseek-ai/DeepSeek-V2-Lite",
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="mp",
|
||||
) as vllm_model:
|
||||
vllm_model.generate(example_prompts, sampling_params)
|
||||
|
||||
|
||||
def test_models_distributed_Qwen3_W8A8():
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
|
||||
139
tests/ut/distributed/test_distributed_tensor_parallel.py
Normal file
139
tests/ut/distributed/test_distributed_tensor_parallel.py
Normal file
@@ -0,0 +1,139 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.distributed.tensor_parallel import (
|
||||
_gather_along_first_dim, _gather_along_last_dim,
|
||||
_reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
|
||||
all_to_all_hp2sp, all_to_all_sp2hp)
|
||||
|
||||
|
||||
class TestDistributedCommunication(PytestBase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def context(self, mocker: MockerFixture):
|
||||
mocker.patch("torch.npu.current_device", return_value="cpu")
|
||||
mocker.patch("torch.distributed.get_world_size", return_value=4)
|
||||
|
||||
mocker.patch("torch.distributed.get_rank", return_value=0)
|
||||
|
||||
@pytest.mark.parametrize("world_size, test_tensor, expected",
|
||||
[(1, torch.randn(8, 16), (8, 16)),
|
||||
(4, torch.randn(8, 16), (32, 16))])
|
||||
def test_gather_along_first_dim(self, test_tensor, expected, world_size,
|
||||
mocker: MockerFixture):
|
||||
"""test _gather_along_first_dim"""
|
||||
mocker.patch("torch.distributed.get_world_size",
|
||||
return_value=world_size)
|
||||
|
||||
result = _gather_along_first_dim(test_tensor, mocker.MagicMock())
|
||||
|
||||
assert result.shape == expected
|
||||
|
||||
@pytest.mark.parametrize("test_tensor, output_split_sizes, expected", [
|
||||
(torch.randn(8, 16), [5, 10, 15, 2], (32, 16)),
|
||||
])
|
||||
def test_gather_along_first_dim_unequal_split(self, test_tensor, expected,
|
||||
output_split_sizes,
|
||||
mocker: MockerFixture):
|
||||
"""test _gather_along_first_dim"""
|
||||
|
||||
result = _gather_along_first_dim(test_tensor, mocker.MagicMock(),
|
||||
output_split_sizes)
|
||||
|
||||
assert result.shape == expected
|
||||
|
||||
@pytest.mark.parametrize("world_size, test_tensor, expected",
|
||||
[(1, torch.randn(8, 16, 32), (8, 16, 32)),
|
||||
(4, torch.randn(8, 16, 32), (8, 16, 32 * 4))])
|
||||
def test_gather_along_last_dim(self, test_tensor, expected, world_size,
|
||||
mocker: MockerFixture):
|
||||
"""test _gather_along_last_dim"""
|
||||
mocker.patch("torch.distributed.get_world_size",
|
||||
return_value=world_size)
|
||||
|
||||
result = _gather_along_last_dim(test_tensor, mocker.MagicMock())
|
||||
|
||||
assert result.shape == expected
|
||||
|
||||
@pytest.mark.parametrize("input_shape,expected_shape", [
|
||||
((32, 16), (8, 16)),
|
||||
((40, 10), (10, 10)),
|
||||
])
|
||||
def test_reduce_scatter_along_first_dim(self, input_shape, expected_shape,
|
||||
mocker: MockerFixture):
|
||||
input_tensor = torch.randn(*input_shape)
|
||||
result = _reduce_scatter_along_first_dim(input_tensor,
|
||||
mocker.MagicMock())
|
||||
assert result.shape == expected_shape
|
||||
|
||||
@pytest.mark.parametrize("input_shape,expected_shape", [
|
||||
((8, 16, 32), (8, 16, 8)),
|
||||
])
|
||||
def test_reduce_scatter_along_last_dim(self, input_shape, expected_shape,
|
||||
mocker: MockerFixture):
|
||||
input_tensor = torch.randn(*input_shape)
|
||||
result = _reduce_scatter_along_last_dim(input_tensor,
|
||||
mocker.MagicMock())
|
||||
assert result.shape == expected_shape
|
||||
|
||||
@pytest.mark.parametrize("func,input_shape,expected_shape", [
|
||||
("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
|
||||
(8, 16, 128)),
|
||||
("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
|
||||
("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
|
||||
(8, 16, 8)),
|
||||
("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
|
||||
])
|
||||
def test_wrapper_functions(self, func, input_shape, expected_shape,
|
||||
mocker: MockerFixture):
|
||||
"""test wrapper funcs"""
|
||||
mod = importlib.import_module(
|
||||
'vllm_ascend.distributed.tensor_parallel')
|
||||
globals = mod.__dict__
|
||||
test_func = globals[func]
|
||||
input_tensor = torch.randn(*input_shape)
|
||||
result = test_func(input_tensor, mocker.MagicMock())
|
||||
assert result.shape == expected_shape
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape,output_shape",
|
||||
[
|
||||
((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP]
|
||||
])
|
||||
def test_all_to_all_sp2hp(self, input_shape, output_shape,
|
||||
mocker: MockerFixture):
|
||||
input_tensor = torch.randn(*input_shape)
|
||||
result = all_to_all_sp2hp(input_tensor, mocker.MagicMock())
|
||||
assert result.shape == output_shape
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape,output_shape",
|
||||
[
|
||||
((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H]
|
||||
])
|
||||
def test_all_to_all_hp2sp(self, input_shape, output_shape,
|
||||
mocker: MockerFixture):
|
||||
input_tensor = torch.randn(*input_shape)
|
||||
result = all_to_all_hp2sp(input_tensor, mocker.MagicMock())
|
||||
assert result.shape == output_shape
|
||||
65
tests/ut/ops/test_token_dispatcher.py
Normal file
65
tests/ut/ops/test_token_dispatcher.py
Normal file
@@ -0,0 +1,65 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
|
||||
from vllm_ascend.utils import adapt_patch # noqa E402
|
||||
|
||||
|
||||
class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase):
|
||||
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
config = MoEDispatcherConfig()
|
||||
config.set_num_local_experts(2)
|
||||
config.set_num_moe_experts(4)
|
||||
config.set_moe_pad_expert_input_to_capacity(False)
|
||||
config.set_moe_expert_capacity_factor(None)
|
||||
config.set_moe_router_topk(2)
|
||||
config.set_moe_grouped_gemm(False)
|
||||
config.set_group_topk(0)
|
||||
config.set_num_groups(1)
|
||||
config.set_is_fused(False)
|
||||
return config.build()
|
||||
|
||||
def mock_ep_group(self, mocker):
|
||||
mock_group = mocker.MagicMock()
|
||||
mock_group.rank_in_group = 0
|
||||
mock_group.world_size = 2
|
||||
mock_group.device_group = "mock_group"
|
||||
return mock_group
|
||||
|
||||
@pytest.fixture
|
||||
def dispatcher(self, config, mocker: MockerFixture):
|
||||
mocker.patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group",
|
||||
return_value=self.mock_ep_group(mocker))
|
||||
mocker.patch("torch.npu.current_device", return_value="cpu")
|
||||
mocker.patch("torch.npu.Stream", return_value=mocker.MagicMock)
|
||||
return MoEAlltoAllSeqOverLapDispatcher(config)
|
||||
|
||||
def test_initialization(self, dispatcher, config):
|
||||
assert dispatcher.num_local_experts == config.num_local_experts
|
||||
assert dispatcher.num_experts == config.num_moe_experts
|
||||
assert dispatcher.local_expert_indices == [0, 1]
|
||||
assert dispatcher.ep_rank == 0
|
||||
assert dispatcher.ep_size == 2
|
||||
assert dispatcher.overlap_stream is not None
|
||||
@@ -18,6 +18,7 @@ class FusedMoEState(Enum):
|
||||
MC2 = 2
|
||||
AllGatherEP = 3
|
||||
NaiveMulticast = 4
|
||||
All2AllSeq = 5
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): add soc_version to choose branch
|
||||
@@ -33,6 +34,10 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||
return FusedMoEState.NaiveMulticast
|
||||
else:
|
||||
return FusedMoEState.AllGather
|
||||
elif envs.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
|
||||
# MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
|
||||
return (FusedMoEState.All2AllSeq if
|
||||
(ep_size < 16 or with_prefill) else FusedMoEState.MC2)
|
||||
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
|
||||
elif ep_size < 16 or with_prefill:
|
||||
return FusedMoEState.All2All
|
||||
|
||||
248
vllm_ascend/distributed/tensor_parallel.py
Normal file
248
vllm_ascend/distributed/tensor_parallel.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapts from: Megatron/megatron/core/tensor_parallel/mappings.py.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
import torch
|
||||
|
||||
|
||||
def _gather_along_first_dim(input_, group, output_split_sizes=None):
|
||||
"""Gather tensors and concatenate along the first dimension.
|
||||
|
||||
Args:
|
||||
input_tensor (torch.Tensor):
|
||||
A tensor to be gathered.
|
||||
output_split_sizes (List[int], optional):
|
||||
A list specifying the sizes of the output splits along the first dimension.
|
||||
If None, equal splitting is assumed. Default: None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Gathered tensor.
|
||||
"""
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
dim_size = list(input_.size())
|
||||
if output_split_sizes is None:
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.all_gather_into_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
else:
|
||||
dim_size[0] = sum(output_split_sizes)
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
output_tensor_list = list(
|
||||
torch.split(output, output_split_sizes, dim=0))
|
||||
torch.distributed.all_gather(output_tensor_list, input_, group=group)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather_along_last_dim(input_, group):
|
||||
"""Gather tensors and concatenate along the last dimension."""
|
||||
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
dim_size = list(input_.size())
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.all_gather_into_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
tensor_list = output.chunk(world_size, dim=0)
|
||||
output = torch.cat(tensor_list, dim=-1).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _reduce_scatter_along_first_dim(input_,
|
||||
group,
|
||||
input_split_sizes=None,
|
||||
use_global_buffer=False):
|
||||
"""Reduce-scatter the input tensor across model parallel group.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): The input tensor to be reduce-scattered.
|
||||
input_split_sizes (List[int], optional): A list specifying the sizes of
|
||||
the input splits along the first dimension for each rank. If None,
|
||||
equal splitting is assumed. Default: None.
|
||||
"""
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
if input_split_sizes is None:
|
||||
dim_size = list(input_.size())
|
||||
assert (
|
||||
dim_size[0] % world_size == 0
|
||||
), "First dimension of the tensor should be divisible by tensor parallel size"
|
||||
|
||||
dim_size[0] = dim_size[0] // world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.reduce_scatter_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
else:
|
||||
rank = torch.distributed.get_rank(group)
|
||||
input_tensor_list = list(torch.split(input_, input_split_sizes, dim=0))
|
||||
|
||||
output = torch.empty_like(input_tensor_list[rank])
|
||||
torch.distributed.reduce_scatter(output,
|
||||
input_tensor_list,
|
||||
group=group)
|
||||
return output
|
||||
|
||||
|
||||
def _reduce_scatter_along_last_dim(input_, group):
|
||||
"""Reduce-scatter tensors on the last dimension."""
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
target_shape = list(input_.size())
|
||||
target_shape[-1] = target_shape[-1] // world_size
|
||||
input_ = input_.reshape(-1, input_.shape[-1])
|
||||
split_tensors = torch.split(input_,
|
||||
split_size_or_sections=input_.shape[-1] //
|
||||
world_size,
|
||||
dim=1)
|
||||
concat_tensor = torch.cat(split_tensors, dim=0)
|
||||
output = _reduce_scatter_along_first_dim(concat_tensor,
|
||||
group).reshape(target_shape)
|
||||
return output
|
||||
|
||||
|
||||
def all_gather_last_dim_from_tensor_parallel_region(input_, group):
|
||||
"""Wrapper for autograd function: forward: AG, backward RS <last dim>"""
|
||||
return _gather_along_last_dim(input_, group)
|
||||
|
||||
|
||||
def reduce_scatter_to_sequence_parallel_region(input_,
|
||||
group,
|
||||
input_split_sizes=None):
|
||||
"""Wrapper for autograd function: forward: RS, backward AG <first dim>"""
|
||||
return _reduce_scatter_along_first_dim(input_, group, input_split_sizes)
|
||||
|
||||
|
||||
def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group):
|
||||
"""Wrapper for autograd function: forward: RS, backward AG: AG <last dim>"""
|
||||
return _reduce_scatter_along_last_dim(input_, group)
|
||||
|
||||
|
||||
def gather_from_sequence_parallel_region(
|
||||
input_,
|
||||
group,
|
||||
output_split_sizes=None,
|
||||
):
|
||||
"""Wrapper for autograd function: forward: AG, backward: RS <first dim>"""
|
||||
return _gather_along_first_dim(input_, group, output_split_sizes)
|
||||
|
||||
|
||||
def all_to_all(group, input, output_split_sizes=None, input_split_sizes=None):
|
||||
world_size = torch.distributed.get_world_size(group=group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input
|
||||
|
||||
input = input.contiguous()
|
||||
if output_split_sizes is None:
|
||||
# Equal split (all2all)
|
||||
output = torch.empty_like(input)
|
||||
else:
|
||||
# Unequal split (all2all-v)
|
||||
output = input.new_empty(
|
||||
size=[sum(output_split_sizes)] + list(input.size()[1:]),
|
||||
dtype=input.dtype,
|
||||
device=torch.npu.current_device(),
|
||||
)
|
||||
torch.distributed.all_to_all_single(
|
||||
output,
|
||||
input,
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def all_to_all_sp2hp(input_, group):
|
||||
"""
|
||||
Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape
|
||||
[num_tokens/TP, H] to [num_tokens, H/TP].
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor):
|
||||
The input tensor which has been distributed along the sequence
|
||||
dimension.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor with shape [num_tokens, H/TP].
|
||||
|
||||
"""
|
||||
if group is None:
|
||||
return input_
|
||||
world_size = torch.distributed.get_world_size(group=group)
|
||||
tp_group = group
|
||||
input_ = input_.reshape(-1, input_.shape[-1])
|
||||
split_tensors = torch.split(input_,
|
||||
split_size_or_sections=input_.shape[-1] //
|
||||
world_size,
|
||||
dim=1)
|
||||
concat_tensor = torch.cat(split_tensors, dim=0)
|
||||
output = all_to_all(tp_group, concat_tensor)
|
||||
return output
|
||||
|
||||
|
||||
def all_to_all_hp2sp(input_, group):
|
||||
"""
|
||||
Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape
|
||||
[num_tokens, H/TP] to [num_tokens/TP, H].
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor):
|
||||
The input tensor which has been distributed along the hidden
|
||||
dimension.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor with shape [num_tokens/TP, H].
|
||||
"""
|
||||
if group is None:
|
||||
return input_
|
||||
world_size = torch.distributed.get_world_size(group=group)
|
||||
input_ = input_.reshape(-1, input_.shape[-1])
|
||||
tp_group = group
|
||||
input_exchanged = all_to_all(tp_group, input_)
|
||||
input_reshaped = input_exchanged.reshape(-1, input_exchanged.shape[-1])
|
||||
split_tensors = torch.split(
|
||||
input_reshaped,
|
||||
split_size_or_sections=input_reshaped.shape[0] // world_size,
|
||||
dim=0)
|
||||
output = torch.cat(split_tensors, dim=-1)
|
||||
return output
|
||||
@@ -154,6 +154,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# this feature is supported in A2, and eager mode will get better performance.
|
||||
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
|
||||
# Whether to enable the alltoall_seq flag, this provides a basic framework on the basis of alltoall for easy expansion.
|
||||
# 0: default, normal init.
|
||||
# 1: enable moe all2all seq.
|
||||
"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ":
|
||||
lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
@@ -40,7 +40,6 @@ def register_model():
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV3ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
|
||||
|
||||
else:
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV2ForCausalLM",
|
||||
|
||||
@@ -15,8 +15,111 @@
|
||||
# limitations under the License.
|
||||
# Adapted from vllm/model_executor/models/qwen3_moe.py
|
||||
# This file is a part of the vllm-ascend project.
|
||||
from typing import Optional
|
||||
|
||||
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
|
||||
Qwen3MoeDecoderLayer,
|
||||
Qwen3MoeForCausalLM,
|
||||
Qwen3MoeMLP, Qwen3MoeModel)
|
||||
from vllm.model_executor.models.utils import (
|
||||
extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
|
||||
from vllm_ascend.platform import VllmConfig
|
||||
|
||||
|
||||
class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = Qwen3MoeAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=getattr(config, 'attention_bias', False),
|
||||
head_dim=getattr(config, 'head_dim', None),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
# `mlp_only_layers` in the config.
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
||||
config.mlp_only_layers)
|
||||
if (layer_idx not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and
|
||||
(layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||
self.mlp = AscendSparseMoeBlock(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CustomQwen3MoeModel(Qwen3MoeModel):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: CustomQwen3MoeDecoderLayer(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
|
||||
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
||||
@@ -33,3 +136,20 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
62
vllm_ascend/ops/comm_utils.py
Normal file
62
vllm_ascend/ops/comm_utils.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
|
||||
COMM_STREAM = None
|
||||
|
||||
|
||||
def async_all_to_all(input_,
|
||||
output_split_sizes,
|
||||
input_split_sizes,
|
||||
group,
|
||||
event=None):
|
||||
if output_split_sizes is None:
|
||||
# Equal split (all2all)
|
||||
a2a_out = torch.empty_like(input_)
|
||||
else:
|
||||
# Unequal split (all2all-v)
|
||||
a2a_out = input_.new_empty(
|
||||
size=[sum(output_split_sizes)] + list(input_.size()[1:]),
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device(),
|
||||
)
|
||||
|
||||
if event:
|
||||
# multi stream wait event
|
||||
global COMM_STREAM
|
||||
if COMM_STREAM is None:
|
||||
COMM_STREAM = torch_npu.npu.Stream(
|
||||
device=torch.npu.current_device())
|
||||
with torch_npu.npu.stream(COMM_STREAM):
|
||||
event.wait()
|
||||
handle = dist.all_to_all_single(
|
||||
a2a_out,
|
||||
input_.contiguous(),
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True)
|
||||
else:
|
||||
handle = dist.all_to_all_single(a2a_out,
|
||||
input_.contiguous(),
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True)
|
||||
return input_, a2a_out, handle
|
||||
@@ -16,12 +16,14 @@
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -35,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.config import \
|
||||
FusedMoEParallelConfig # isort: skip
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
|
||||
@@ -45,6 +48,8 @@ from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||
get_all_reduce_merge_state,
|
||||
@@ -273,11 +278,13 @@ def fused_experts_with_mc2(
|
||||
return hidden_states, shared_hidden_states
|
||||
|
||||
|
||||
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1) -> torch.Tensor:
|
||||
def apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||
|
||||
@@ -299,9 +306,6 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
hidden_states: output hidden states after MLP.
|
||||
"""
|
||||
|
||||
assert len(hidden_states_wrapper) == 1
|
||||
hidden_states = hidden_states_wrapper.pop()
|
||||
|
||||
w1 = w1.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
@@ -329,6 +333,8 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
return hidden_states
|
||||
|
||||
|
||||
# currently expert parallelism implemented with all2all
|
||||
# is under-optimized.
|
||||
def fused_experts_with_all2all(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -543,10 +549,7 @@ def fused_experts_with_all2all_buffer(
|
||||
hidden_states = hidden_states[sorted_idx]
|
||||
group_list_type = 0
|
||||
|
||||
hidden_states_wrapper = [hidden_states]
|
||||
del hidden_states
|
||||
|
||||
hidden_states = apply_mlp(hidden_states_wrapper,
|
||||
hidden_states = apply_mlp(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
expert_tokens,
|
||||
@@ -682,6 +685,24 @@ def fused_experts_moge(
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def fused_experts_with_all2allv(
|
||||
token_dispatcher,
|
||||
probs,
|
||||
routing_map,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
):
|
||||
# Enable moe alltoallv, it's a balanced policy for precision and efficiency.
|
||||
(share_experts_output, dispatched_input,
|
||||
tokens_per_expert) = (token_dispatcher.token_permutation(
|
||||
hidden_states, probs, routing_map))
|
||||
|
||||
expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert)
|
||||
output, mlp_bias = token_dispatcher.token_unpermutation(expert_output)
|
||||
return output
|
||||
|
||||
|
||||
def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -1124,6 +1145,16 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
global_batch_size=self.global_batch_size,
|
||||
expert_map=expert_map,
|
||||
ep_group=get_ep_group())
|
||||
elif fused_moe_state == FusedMoEState.All2AllSeq:
|
||||
token_dispatcher = kwargs.get("token_dispatcher")
|
||||
return fused_experts_with_all2allv(
|
||||
token_dispatcher=token_dispatcher,
|
||||
probs=topk_weights,
|
||||
routing_map=topk_ids,
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
)
|
||||
else:
|
||||
return fused_experts_with_all2all(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -1275,6 +1306,25 @@ class AscendFusedMoE(FusedMoE):
|
||||
# NOTE: self.tp_group is not expert_tp_group
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
self.token_dispatcher = None
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance(
|
||||
self.quant_method, AscendUnquantizedFusedMoEMethod):
|
||||
self.reduce_results = False
|
||||
moe_dispatcher_config = (
|
||||
MoEDispatcherConfig().set_num_moe_experts(
|
||||
self.global_num_experts).set_num_local_experts(
|
||||
self.local_num_experts).set_moe_router_topk(
|
||||
top_k).set_group_topk(topk_group).
|
||||
set_num_groups(num_expert_group).set_expert_bias(
|
||||
e_score_correction_bias).set_scaling_factor(1.0).build())
|
||||
self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(
|
||||
moe_dispatcher_config)
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
||||
token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher(
|
||||
moe_dispatcher_config)
|
||||
self.token_dispatchers = [
|
||||
self.token_dispatcher, token_dispatcher1
|
||||
]
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
@@ -1414,6 +1464,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
shared_experts=shared_experts if self.torchair_graph_enabled
|
||||
and self.enable_multistream_moe and not is_prefill else None,
|
||||
mc2_mask=mc2_mask,
|
||||
token_dispatcher=self.token_dispatcher,
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
)
|
||||
@@ -1430,11 +1481,11 @@ class AscendFusedMoE(FusedMoE):
|
||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
dispose_tensor(e_hidden_states)
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
if num_tokens < padding_size:
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
dispose_tensor(e_hidden_states)
|
||||
elif self.dp_size > 1:
|
||||
if fused_moe_state == FusedMoEState.NaiveMulticast:
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
@@ -1491,6 +1542,83 @@ class AscendFusedMoE(FusedMoE):
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
is_prefill=is_prefill,
|
||||
enable_force_load_balance=enable_force_load_balance)
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AscendSparseMoeBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
if self.tp_size > config.num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.num_experts}.")
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.enable_multistream_moe = (
|
||||
ascend_config.torchair_graph_config.enable_multistream_moe)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
self.experts = AscendFusedMoE(
|
||||
num_experts=config.num_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
)
|
||||
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
self.dp_size = get_dp_group().world_size
|
||||
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
self.params_dtype = torch.get_default_dtype()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_metadata is None:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# when profile runs, force experts to load balanced tokens
|
||||
# to avoid high memory consumption on a single rank.
|
||||
enable_force_load_balance = get_forward_context().in_profile_run
|
||||
is_prefill = get_forward_context().with_prefill
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=self.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
shared_experts=None,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
0
vllm_ascend/ops/moe_dispatcher/__init__.py
Normal file
0
vllm_ascend/ops/moe_dispatcher/__init__.py
Normal file
453
vllm_ascend/ops/moe_dispatcher/token_dispatcher.py
Normal file
453
vllm_ascend/ops/moe_dispatcher/token_dispatcher.py
Normal file
@@ -0,0 +1,453 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.distributed.tensor_parallel import (
|
||||
all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp,
|
||||
all_to_all_sp2hp, gather_from_sequence_parallel_region,
|
||||
reduce_scatter_last_dim_to_tensor_parallel_region)
|
||||
from vllm_ascend.ops.comm_utils import async_all_to_all
|
||||
|
||||
|
||||
class MoEDispatcherConfig:
|
||||
|
||||
def __init__(self):
|
||||
self.num_local_experts: int = 0
|
||||
self.num_moe_experts: int = 0
|
||||
self.moe_pad_expert_input_to_capacity: bool = False
|
||||
self.moe_expert_capacity_factor: Optional[float] = None
|
||||
self.moe_router_topk: int = 2
|
||||
self.moe_grouped_gemm: bool = False
|
||||
self.group_topk: int = 0
|
||||
self.num_groups: int = 1
|
||||
self.expert_bias: torch.Tensor = None
|
||||
self.scaling_factor: Optional[float] = None
|
||||
self.is_fused: bool = True
|
||||
|
||||
def set_num_local_experts(self, num_local_experts):
|
||||
self.num_local_experts = num_local_experts
|
||||
return self
|
||||
|
||||
def set_num_moe_experts(self, num_moe_experts):
|
||||
self.num_moe_experts = num_moe_experts
|
||||
return self
|
||||
|
||||
def set_moe_pad_expert_input_to_capacity(self,
|
||||
moe_pad_expert_input_to_capacity):
|
||||
self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity
|
||||
return self
|
||||
|
||||
def set_moe_expert_capacity_factor(self, moe_expert_capacity_factor):
|
||||
self.moe_expert_capacity_factor = moe_expert_capacity_factor
|
||||
return self
|
||||
|
||||
def set_moe_router_topk(self, moe_router_topk):
|
||||
self.moe_router_topk = moe_router_topk
|
||||
return self
|
||||
|
||||
def set_moe_grouped_gemm(self, moe_grouped_gemm):
|
||||
self.moe_grouped_gemm = moe_grouped_gemm
|
||||
return self
|
||||
|
||||
def set_group_topk(self, group_topk):
|
||||
self.group_topk = group_topk
|
||||
return self
|
||||
|
||||
def set_num_groups(self, num_groups):
|
||||
self.num_groups = num_groups
|
||||
return self
|
||||
|
||||
def set_expert_bias(self, expert_bias):
|
||||
self.expert_bias = expert_bias
|
||||
return self
|
||||
|
||||
def set_scaling_factor(self, scaling_factor):
|
||||
self.scaling_factor = scaling_factor
|
||||
return self
|
||||
|
||||
def set_is_fused(self, is_fused):
|
||||
self.is_fused = is_fused
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self
|
||||
|
||||
|
||||
class MoEDispatcher:
|
||||
|
||||
def __init__(self, config: MoEDispatcherConfig) -> None:
|
||||
"""
|
||||
Initialize the MoE Token Dispatcher.
|
||||
"""
|
||||
self.config = config
|
||||
self.shared_experts = None
|
||||
|
||||
def set_shared_experts(self, shared_experts):
|
||||
self.shared_experts = shared_experts
|
||||
|
||||
@property
|
||||
def ep_group(self):
|
||||
"""Get expert model parallel group."""
|
||||
return get_ep_group().device_group
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return get_ep_group().rank_in_group
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return get_ep_group().world_size
|
||||
|
||||
@property
|
||||
def tp_ep_group(self):
|
||||
"""Get expert tensor and model parallel group."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def tp_ep_size(self):
|
||||
return 1
|
||||
|
||||
|
||||
class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher):
|
||||
overlap_stream = None
|
||||
"""
|
||||
The implementation of the AlltoAll-based token dispatcher, which handles token
|
||||
dispatching on the sequence level instead of token level. The core of this implementation
|
||||
lies in each device dispatching on the entire sequence, with the hidden state being partitioned.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: MoEDispatcherConfig):
|
||||
"""
|
||||
Initialize the AlltoAllSeq token dispatcher.
|
||||
|
||||
Args:
|
||||
config (MoEDispatcherConfig): Configuration for the transformer model.
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.num_local_experts = config.num_local_experts
|
||||
self.config = config
|
||||
# use MOEAlltoAllSEQTokenDispatcher to init
|
||||
|
||||
self.hidden_shape = None
|
||||
self.num_input_tokens = None
|
||||
self.num_experts = config.num_moe_experts
|
||||
assert self.num_local_experts > 0, "Expected at least one expert"
|
||||
if self.num_local_experts > 1:
|
||||
self.expert_ids_per_ep_rank = torch.tensor(
|
||||
[i % self.num_local_experts for i in range(self.num_experts)],
|
||||
dtype=torch.int32,
|
||||
device=torch.npu.current_device(),
|
||||
)
|
||||
|
||||
local_expert_indices_offset = (self.ep_rank * self.num_local_experts)
|
||||
|
||||
self.local_expert_indices = [
|
||||
local_expert_indices_offset + i
|
||||
for i in range(self.num_local_experts)
|
||||
]
|
||||
assert (len(self.local_expert_indices) == self.num_local_experts
|
||||
), "Invalid local expert indices"
|
||||
for i in range(len(self.local_expert_indices) - 1):
|
||||
assert (self.local_expert_indices[i] ==
|
||||
self.local_expert_indices[i + 1] -
|
||||
1), "local_expert_indices must be continuous"
|
||||
self.probs = None
|
||||
self.input_splits = None
|
||||
self.output_splits = None
|
||||
self.routing_map = None
|
||||
self.hidden_shape_before_permute = None
|
||||
|
||||
# [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent
|
||||
# to each local expert by all ranks.
|
||||
self.num_global_tokens_per_local_expert_cpu = None
|
||||
self.num_global_tokens_per_local_expert = None
|
||||
|
||||
# A cuda stream synchronization is needed in self.token_permutation()
|
||||
# in some cases, because there are several non-blocking DtoH data
|
||||
# transfers called in self.preprocess(). The synchronization happens
|
||||
# at different points based on MoE settings as late as possible.
|
||||
# Valid sync points are "before_permutation_1", "before_ep_alltoall",
|
||||
# "before_finish", and "no_sync".
|
||||
self.device_sync_point = "no_sync"
|
||||
|
||||
# cached intermediate tensors.
|
||||
self.cached_permutated_local_input_tokens = None
|
||||
self.cached_global_input_tokens = None
|
||||
self.cached_shared_expert_output = None
|
||||
self.tokens_per_expert = None
|
||||
self.perm1_finish_event = None
|
||||
self.global_input_tokens_local_experts_indices = None
|
||||
|
||||
if MoEAlltoAllSeqOverLapDispatcher.overlap_stream is None:
|
||||
MoEAlltoAllSeqOverLapDispatcher.overlap_stream = torch.npu.Stream()
|
||||
|
||||
self.overlap_stream = MoEAlltoAllSeqOverLapDispatcher.overlap_stream
|
||||
|
||||
def preprocess(self,
|
||||
indices: torch.Tensor,
|
||||
with_sync=True) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess routing map for AlltoAll communication and token permutation.
|
||||
This method computes the number of tokens assigned to each expert based on
|
||||
the routing map. It also initializes the necessary data structures for
|
||||
AlltoAll communication, such as input and output splits, and the mapping
|
||||
between global tokens and local experts.
|
||||
|
||||
Args:
|
||||
routing_map (torch.Tensor): The mapping of tokens to experts, with shape
|
||||
[num_tokens, num_experts].
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Tensor containing the number of tokens assigned to local expert.
|
||||
"""
|
||||
num_local_tokens_per_expert = torch.histc(indices,
|
||||
bins=self.num_experts,
|
||||
min=0,
|
||||
max=self.num_experts)
|
||||
|
||||
# num_local_tokens_per_expert: [num_experts]
|
||||
|
||||
ep_size = self.ep_size
|
||||
|
||||
# Dropless
|
||||
self.num_out_tokens = indices.numel()
|
||||
if self.ep_size > 1 or self.num_local_experts > 1:
|
||||
# Token dropless and enable ep. A synchronization is needed before expert parallel
|
||||
# AlltoAll communication to get the `input_splits` and `output_splits` CPU values.
|
||||
self.device_sync_point = "before_ep_alltoall"
|
||||
else:
|
||||
# Token dropless and no ep. A synchronization is needed to get the
|
||||
# `tokens_per_expert` CPU value.
|
||||
self.device_sync_point = "before_finish"
|
||||
|
||||
if ep_size > 1:
|
||||
# ===================================================
|
||||
# Calculate input_splits, output_splits for alltoall-v.
|
||||
# ===================================================
|
||||
self.input_splits = (num_local_tokens_per_expert.reshape(
|
||||
ep_size, self.num_local_experts).sum(axis=1).to(
|
||||
torch.device("cpu"), non_blocking=True).numpy())
|
||||
num_global_tokens_per_expert = gather_from_sequence_parallel_region(
|
||||
num_local_tokens_per_expert,
|
||||
group=self.ep_group).reshape(ep_size, self.num_experts)
|
||||
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
|
||||
0]:self.local_expert_indices[-1] + 1]
|
||||
if self.num_global_tokens_per_local_expert is None:
|
||||
raise ValueError(
|
||||
"num_global_tokens_per_local_expert must be set before sum."
|
||||
)
|
||||
self.output_splits = (self.num_global_tokens_per_local_expert.sum(
|
||||
axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
|
||||
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
|
||||
axis=0)
|
||||
# ===================================================
|
||||
# num_global_tokens_per_expert: [ep_size, num_experts]
|
||||
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
|
||||
# num_tokens_per_local_expert: [num_local_experts]
|
||||
# ===================================================
|
||||
else:
|
||||
self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
|
||||
-1, self.num_experts)
|
||||
num_tokens_per_local_expert = num_local_tokens_per_expert
|
||||
|
||||
if self.num_local_experts > 1 and with_sync:
|
||||
if self.num_global_tokens_per_local_expert is None:
|
||||
raise ValueError(
|
||||
"num_global_tokens_per_local_expert must be set before operations."
|
||||
)
|
||||
self.device_sync_point = "no_sync"
|
||||
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
|
||||
self.expert_ids_per_ep_rank,
|
||||
self.num_global_tokens_per_local_expert.ravel())
|
||||
|
||||
return num_tokens_per_local_expert
|
||||
|
||||
def token_permutation(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
probs: torch.Tensor,
|
||||
routing_map: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Dispatch tokens to local experts using AlltoAllSeq communication.
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input token embeddings.
|
||||
probs (torch.Tensor): Probs of tokens assigned to experts.
|
||||
Shape: [num_tokens, num_experts].
|
||||
routing_map (torch.Tensor): Mapping of tokens assigned to experts.
|
||||
Shape: [num_tokens, num_experts].
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
- Permuted token embeddings for local experts.
|
||||
- Number of tokens per expert.
|
||||
"""
|
||||
self.hidden_shape = hidden_states.shape
|
||||
self.probs = probs
|
||||
self.top_indices = routing_map
|
||||
assert probs.dim() == 2, "Expected 2D tensor for probs"
|
||||
assert routing_map.dim() == 2, "Expected 2D tensor for routing map"
|
||||
|
||||
# Permutation 1: input to AlltoAll input
|
||||
def alltoall_token_permutation1(hidden_states, routing_map):
|
||||
assert self.hidden_shape is not None
|
||||
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
|
||||
tokens_per_expert = self.preprocess(routing_map)
|
||||
if self.tp_ep_size > 1:
|
||||
hidden_states = all_to_all_sp2hp(hidden_states,
|
||||
group=self.tp_ep_group)
|
||||
self.hidden_shape_before_permute = hidden_states.shape
|
||||
|
||||
if self.device_sync_point == "before_permutation_1":
|
||||
torch.npu.current_stream().synchronize()
|
||||
|
||||
permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
tokens=hidden_states,
|
||||
indices=self.top_indices,
|
||||
num_out_tokens=self.num_out_tokens,
|
||||
)
|
||||
return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert
|
||||
|
||||
permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = alltoall_token_permutation1(
|
||||
hidden_states, routing_map)
|
||||
self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping
|
||||
# permute 1
|
||||
|
||||
ep_group = self.ep_group
|
||||
|
||||
# Perform expert parallel AlltoAll communication
|
||||
if self.device_sync_point == "before_ep_alltoall":
|
||||
torch.npu.current_stream().synchronize()
|
||||
_, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
|
||||
permutated_local_input_tokens,
|
||||
self.output_splits,
|
||||
self.input_splits,
|
||||
ep_group,
|
||||
)
|
||||
|
||||
# shared experts compute
|
||||
if self.shared_experts is not None:
|
||||
(share_experts_output), *_ = self.shared_experts(hidden_states)
|
||||
else:
|
||||
share_experts_output = None
|
||||
|
||||
permute1_ep_all_to_all_handle.wait()
|
||||
permutated_local_input_tokens.untyped_storage().resize_(0)
|
||||
|
||||
def alltoall_token_permutation2(global_input_tokens):
|
||||
# Permutation 2: Sort tokens by local expert.
|
||||
if self.num_local_experts > 1:
|
||||
global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
global_input_tokens,
|
||||
self.global_input_tokens_local_experts_indices)
|
||||
|
||||
# Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens.
|
||||
# global_input_tokens: [SEQL, H/TP] -> [SEQL, H]
|
||||
if self.tp_ep_size > 1 and self.config.moe_grouped_gemm:
|
||||
global_input_tokens = all_gather_last_dim_from_tensor_parallel_region(
|
||||
global_input_tokens, self.tp_ep_group)
|
||||
if self.device_sync_point == "before_finish":
|
||||
torch.npu.current_stream().synchronize()
|
||||
|
||||
return global_input_tokens
|
||||
|
||||
# token premute2 input
|
||||
global_input_tokens = alltoall_token_permutation2(global_input_tokens)
|
||||
|
||||
return share_experts_output, global_input_tokens, tokens_per_expert
|
||||
|
||||
def token_unpermutation(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
"""
|
||||
Reverse the token permutation to restore the original order.
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Output from local experts.
|
||||
bias (torch.Tensor, optional): Bias tensor (not supported).
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
- Unpermuted token embeddings in the original order.
|
||||
- None (bias is not supported).
|
||||
"""
|
||||
|
||||
def alltoall_token_unpermutation1(hidden_states):
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher"
|
||||
# Perform tensor parallel Reduce-Scatter
|
||||
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
|
||||
if self.tp_ep_size > 1:
|
||||
hidden_states = reduce_scatter_last_dim_to_tensor_parallel_region(
|
||||
hidden_states, group=self.tp_ep_group)
|
||||
|
||||
# Unpermutation 2: expert output to AlltoAll input
|
||||
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
|
||||
hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||
hidden_states,
|
||||
self.reversed_global_input_permutation_mapping)
|
||||
|
||||
return hidden_states
|
||||
|
||||
hidden_states = alltoall_token_unpermutation1(hidden_states)
|
||||
|
||||
ep_group = self.ep_group
|
||||
# Perform expert parallel AlltoAll communication
|
||||
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
|
||||
_, permutated_local_input_tokens, handle = async_all_to_all(
|
||||
hidden_states, self.input_splits, self.output_splits, ep_group)
|
||||
handle.wait()
|
||||
hidden_states.untyped_storage().resize_(0)
|
||||
|
||||
def alltoall_token_unpermutation2(permutated_local_input_tokens):
|
||||
# Unpermutation 1: AlltoAll output to output
|
||||
|
||||
output = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=permutated_local_input_tokens,
|
||||
sorted_indices=self.reversed_local_input_permutation_mapping.
|
||||
to(torch.int32),
|
||||
probs=self.probs,
|
||||
restore_shape=self.hidden_shape_before_permute)
|
||||
|
||||
# Perform tensor parallel AlltoAll communication
|
||||
# output: [S*B, H/TP] -> [S*B/TP, H]
|
||||
if self.tp_ep_size > 1:
|
||||
output = all_to_all_hp2sp(output, self.tp_ep_group)
|
||||
|
||||
# Reshape the output tensor
|
||||
output = output.view(self.hidden_shape)
|
||||
return output
|
||||
|
||||
output = alltoall_token_unpermutation2(permutated_local_input_tokens)
|
||||
|
||||
self.input_splits = None
|
||||
self.output_splits = None
|
||||
self.num_global_tokens_per_local_expert = None
|
||||
self.num_global_tokens_per_local_expert_cpu = None
|
||||
|
||||
return output, None
|
||||
Reference in New Issue
Block a user