[2/N][Feat] Add MC2 communication method for MoE layers (#2469)
### What this PR does / why we need it?
This PR adds the MC2 communication method for MoE layers; it replaces the
previous all-gather approach when the number of input tokens is small.
The key changes include:
- A new `AscendFusedMoE` layer that handles token splitting, local
computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2
method and the existing all-gather method based on the number of input
tokens (see the sketch after this list).
- Sharding the MoE communication mask across tensor-parallel ranks.
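
A rough, hedged sketch of that selection logic, in Python. The threshold value and the helper name `select_moe_comm_method` are illustrative assumptions; only the fact that the runner picks `"mc2"` for small token counts and `"allgather"` otherwise, and then hands the name to `set_ascend_forward_context`, comes from this PR.

```python
# Hypothetical sketch of the per-forward-pass selection; the threshold and
# function name are NOT part of this PR, only the mc2/allgather choice is.
def select_moe_comm_method(num_input_tokens: int,
                           mc2_token_threshold: int = 256) -> str:
    # MC2 (fused dispatch/combine) is preferred for small token counts,
    # e.g. decode steps; the all-gather path is used otherwise.
    if num_input_tokens <= mc2_token_threshold:
        return "mc2"
    return "allgather"

# The chosen name is later passed to set_ascend_forward_context(...), which
# appends "commimpl" to resolve the concrete implementation on the layer.
```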
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
The existing unit tests were fixed and updated to the new interface.
- vLLM version: v0.10.1.1
- vLLM main: b00e69f8ca
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
@@ -18,29 +18,30 @@ from types import SimpleNamespace

 import pytest
 import torch
-from transformers import PretrainedConfig
-from vllm import forward_context

-from vllm_ascend.distributed import moe_comm_method
-from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
-                                                     NativeAllGatherCommImpl)
+from vllm.model_executor.layers.fused_moe.config import (  # isort: skip
+    FusedMoEConfig, FusedMoEParallelConfig)
+from vllm_ascend.distributed.moe_comm_method import (  # isort: skip
+    AllGatherCommImpl, NativeAllGatherCommImpl)


 @pytest.mark.parametrize("num_tokens", [16, 128])
 @pytest.mark.parametrize("hidden_size", [64, 128])
 @pytest.mark.parametrize("global_num_experts", [8, 16])
+@pytest.mark.parametrize("num_local_experts", [4, 8])
 @pytest.mark.parametrize("top_k_num", [2, 4])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("num_local_experts", [4, 8])
 @pytest.mark.parametrize("ep_rank", [0, 1])
 def test_all_gather_comm_impl(
     num_tokens,
     hidden_size,
     global_num_experts,
+    num_local_experts,
     top_k_num,
     dtype,
-    num_local_experts,
     ep_rank,
+    mocker,
 ):
     """
     Tests the AllGatherCommImpl against the NativeAllGatherCommImpl.
@@ -56,23 +57,37 @@ def test_all_gather_comm_impl(
         "num_local_experts cannot be greater than global_num_experts")

     device = torch.device("npu")
-    hf_config = PretrainedConfig(
-        num_experts_per_tok=top_k_num,
+    # mock get_tensor_model_parallel_rank to return ep_rank
+    mocker.patch(
+        "vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank",
+        return_value=ep_rank,
+    )
+
+    # make moe config
+    parallel_config = SimpleNamespace(
+        enable_expert_parallel=num_local_experts < global_num_experts)
+    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
+        tp_size_=max(2, global_num_experts // num_local_experts),
+        dp_size_=1,
+        vllm_parallel_config=parallel_config,
+    )
+
+    moe_config = FusedMoEConfig(
         num_experts=global_num_experts,
+        experts_per_token=top_k_num,
+        hidden_dim=hidden_size,
+        num_local_experts=num_local_experts,
+        moe_parallel_config=moe_parallel_config,
+        in_dtype=dtype,
+        quant_config=None,  # No quantization in this test
+        max_num_tokens=num_tokens,
     )

     # Instantiate implementations
-    native_impl = NativeAllGatherCommImpl(device, dtype, hf_config)
+    native_impl = NativeAllGatherCommImpl(moe_config)

-    all_gather_impl = AllGatherCommImpl(device, dtype, hf_config)
-
-    # TODO: Find out if this is the correct way to mock the forward context and ep group
-    # Mock get_forward_context to return an object with moe_comm_method
-    forward_context._forward_context = SimpleNamespace(
-        moe_comm_method=all_gather_impl)
-    # Mock get_ep_group to return a fake group with the specified ep_rank
-    fake_ep_group = SimpleNamespace(rank_in_group=ep_rank)
-    moe_comm_method.get_ep_group = lambda: fake_ep_group
+    all_gather_impl = AllGatherCommImpl(moe_config)

     # --- Input Data ---
     hidden_states = torch.randn(num_tokens,
@@ -103,11 +118,11 @@ def test_all_gather_comm_impl(
         native_permuted_hidden,
         native_expert_tokens,
         _,
-    ) = native_impl._pre_process(hidden_states, topk_ids, topk_weights,
-                                 expert_map, num_experts)
+    ) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
+                            num_experts)
     # Simulate MLP output
     native_mlp_output = torch.randn_like(native_permuted_hidden)
-    native_impl._post_process(native_mlp_output, native_hidden_states_out)
+    native_impl.unpermute(native_mlp_output, native_hidden_states_out)

     # --- Run AllGather Implementation ---
     all_gather_hidden_states_out = hidden_states.clone()
@@ -115,15 +130,14 @@ def test_all_gather_comm_impl(
         all_gather_permuted_hidden,
         all_gather_expert_tokens,
         _,
-    ) = torch.ops.vllm.moe_comm_pre_process(hidden_states, topk_ids,
-                                            topk_weights, expert_map,
-                                            num_experts)
+    ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
+                                expert_map, num_experts)

     # Use the same simulated MLP output for a fair comparison
     all_gather_mlp_output = native_mlp_output.clone()

-    torch.ops.vllm.moe_comm_post_process(all_gather_mlp_output,
+    all_gather_impl.unpermute(all_gather_mlp_output,
                               all_gather_hidden_states_out)

     # --- Assertions ---
     # Define tolerance based on dtype
@@ -1,5 +1,5 @@
 import unittest
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, patch

 import torch
 import torch.distributed as dist
@@ -87,69 +87,3 @@ class TestNPUCommunicator(unittest.TestCase):
         output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0)

         assert output.tolist() == [[10, 20], [50, 60]]
-
-    @patch("vllm.config.get_current_vllm_config", return_value=None)
-    @patch("torch.npu.current_device", return_value=MagicMock())
-    @patch("torch.npu.set_device", return_value=MagicMock())
-    @patch("torch.distributed.get_process_group_ranks",
-           return_value={
-               0: 0,
-               1: 1
-           })
-    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
-    @patch("torch.distributed.is_initialized", return_value=True)
-    @patch("torch.distributed.get_rank", return_value=1)
-    @patch("torch.distributed.is_initialized", return_value=True)
-    @patch("torch.distributed.get_backend", return_value="hccl")
-    @patch("torch.distributed.get_rank", return_value=1)
-    @patch("torch.distributed.get_world_size", return_value=2)
-    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
-    @patch("torch.npu.device")
-    def test_dispatch(self, *_):
-        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
-        comm.all2all_manager = Mock()
-        hidden_states = torch.randn(2, 4, 8)
-        router_logits = torch.randn(2, 4, 2)
-
-        mock_dispatch_result = (torch.randn(2, 4, 8), torch.randn(2, 4, 2))
-        comm.all2all_manager.dispatch.return_value = mock_dispatch_result
-
-        result_hidden, result_logits = comm.dispatch(hidden_states,
-                                                     router_logits)
-
-        assert torch.allclose(result_hidden, mock_dispatch_result[0])
-        assert torch.allclose(result_logits, mock_dispatch_result[1])
-
-        comm.all2all_manager.dispatch.assert_called_once_with(
-            hidden_states, router_logits)
-
-    @patch("vllm.config.get_current_vllm_config", return_value=None)
-    @patch("torch.npu.current_device", return_value=MagicMock())
-    @patch("torch.npu.set_device", return_value=MagicMock())
-    @patch("torch.distributed.get_process_group_ranks",
-           return_value={
-               0: 0,
-               1: 1
-           })
-    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
-    @patch("torch.distributed.is_initialized", return_value=True)
-    @patch("torch.distributed.get_rank", return_value=1)
-    @patch("torch.distributed.is_initialized", return_value=True)
-    @patch("torch.distributed.get_backend", return_value="hccl")
-    @patch("torch.distributed.get_rank", return_value=1)
-    @patch("torch.distributed.get_world_size", return_value=2)
-    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
-    @patch("torch.npu.device")
-    def test_combine(self, *_):
-        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
-        comm.all2all_manager = Mock()
-        hidden_states = torch.randn(2, 4, 8)
-
-        mock_combine_result = torch.randn(2, 4, 8)
-        comm.all2all_manager.combine.return_value = mock_combine_result
-
-        result = comm.combine(hidden_states)
-
-        assert torch.allclose(result, mock_combine_result)
-
-        comm.all2all_manager.combine.assert_called_once_with(hidden_states)
@@ -289,13 +289,13 @@ class TestUtils(TestBase):
         # ascend custom op is not registered
         utils.register_ascend_customop()
         # should call register_oot three
-        self.assertEqual(mock_customop.register_oot.call_count, 8)
+        self.assertEqual(mock_customop.register_oot.call_count, 9)
         self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)

         # ascend custom op is already registered
         utils.register_ascend_customop()
         # should not register_oot again, thus only called three in this ut
-        self.assertEqual(mock_customop.register_oot.call_count, 8)
+        self.assertEqual(mock_customop.register_oot.call_count, 9)


 class TestProfileExecuteDuration(TestBase):
@@ -11,7 +11,6 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
                                   set_forward_context)

 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.distributed.moe_comm_method import MoECommMethod


 class FusedMoEState(Enum):
@@ -57,7 +56,7 @@ def set_ascend_forward_context(
         with_prefill: bool = True,
         in_profile_run: bool = False,
         reserved_mc2_mask: Optional[torch.Tensor] = None,
-        moe_comm_method: Optional[MoECommMethod] = None,
+        moe_comm_method: str = "",
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         batch_descriptor: Optional[BatchDescriptor] = None):
@@ -75,7 +74,7 @@ def set_ascend_forward_context(
             batch_descriptor=batch_descriptor,
     ):
         forward_context = get_forward_context()
-        forward_context.moe_comm_method = moe_comm_method
+        forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
         forward_context.with_prefill = with_prefill
         ep_size = (get_ep_group().world_size if
                    vllm_config.parallel_config.enable_expert_parallel else 1)
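For reference, a minimal, self-contained sketch of the string-based dispatch set up above: the forward context stores only a lowercase method name with the suffix `"commimpl"`, and the MoE layer later resolves it with `getattr` (as `AscendFusedMoE.forward_impl` does further down in this diff). The stand-in class is illustrative only.

```python
# Minimal sketch of the name-based dispatch; _FakeLayer stands in for the
# real AscendFusedMoE layer, which pre-constructs one attribute per impl.
class _FakeLayer:
    def __init__(self):
        self.mc2commimpl = "MC2CommImpl instance"
        self.allgathercommimpl = "AllGatherCommImpl instance"

def resolve_comm_method(layer, moe_comm_method: str):
    # set_ascend_forward_context stores `moe_comm_method + "commimpl"` ...
    name = moe_comm_method + "commimpl"
    # ... and forward_impl resolves it with getattr on the layer.
    return getattr(layer, name)

assert resolve_comm_method(_FakeLayer(), "mc2") == "MC2CommImpl instance"
```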
@@ -20,7 +20,6 @@
 import torch.distributed as dist
 from vllm.distributed.device_communicators.base_device_communicator import \
     DeviceCommunicatorBase
-from vllm.utils import logger


 class NPUCommunicator(DeviceCommunicatorBase):
@@ -35,12 +34,6 @@ class NPUCommunicator(DeviceCommunicatorBase):
         # init device according to rank
         self.device = torch.npu.current_device()

-        if self.use_all2all:
-            from vllm.distributed.device_communicators.all2all import \
-                NaiveAll2AllManager
-            self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
-            logger.info("Using naive all2all manager.")
-
     def all_to_all(self,
                    input_: torch.Tensor,
                    scatter_dim: int = 0,
@@ -80,17 +73,3 @@ class NPUCommunicator(DeviceCommunicatorBase):
         dist.all_to_all(output_list, input_list, group=self.device_group)
         output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
         return output_tensor
-
-    # TODO: Add ut for dispatch and combine
-    def dispatch(
-            self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        assert self.all2all_manager is not None
-        hidden_states, router_logits = self.all2all_manager.dispatch(
-            hidden_states, router_logits)
-        return hidden_states, router_logits
-
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        assert self.all2all_manager is not None
-        hidden_states = self.all2all_manager.combine(hidden_states)
-        return hidden_states
@@ -1,12 +1,18 @@
 from abc import ABC, abstractmethod
+from typing import Optional

 import torch
+import torch.distributed as dist
+import torch.nn as nn
 import torch_npu
-from transformers.configuration_utils import PretrainedConfig
-from vllm.distributed.parallel_state import get_ep_group, get_tp_group
-from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.utils import direct_register_custom_op
+from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.distributed.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.fused_moe import FusedMoEConfig

+from vllm_ascend.distributed.communication_op import \
+    data_parallel_reduce_scatter
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
@@ -14,26 +20,34 @@ from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
 class MoECommMethod(ABC):
     """Base class for MoE communication methods."""

-    def __init__(
-        self,
-        device: torch.device,
-        dtype: torch.dtype,
-        hf_config: PretrainedConfig,
-    ):
-        self.device = device
-        self.dtype = dtype
-        self.top_k_num = getattr(hf_config, "num_experts_per_tok", 0)
-        # global_num_experts may be called num_experts or n_routed_experts in different models.
-        possible_keys = ["num_experts", "n_routed_experts"]
-        for key in possible_keys:
-            if hasattr(hf_config, key):
-                self.global_num_experts = getattr(hf_config, key)
-                break
-        else:
-            self.global_num_experts = 0
+    def __init__(self, moe_config: FusedMoEConfig):
+        self.moe_config = moe_config

     @abstractmethod
-    def _pre_process(
+    def prepare(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Prepare the MoE communication method.
+
+        This method is called before quant_method.apply to prepare the
+        communication method. It can be used to initialize any necessary
+        resources or configurations.
+        """
+        pass
+
+    @abstractmethod
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
+        """Finalize the MoE communication method.
+
+        This method is called after quant_method.apply to finalize the
+        communication method. It can be used to clean up any resources or
+        configurations.
+        """
+        pass
+
+    @abstractmethod
+    def permute(
         self,
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
@@ -67,8 +81,8 @@ class MoECommMethod(ABC):
         pass

     @abstractmethod
-    def _post_process(self, mlp_output: torch.Tensor,
+    def unpermute(self, mlp_output: torch.Tensor,
                       hidden_states: torch.Tensor) -> None:
         """Post-process after MLP.

         Args:
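To make the call order of these hooks concrete, here is a hedged sketch of how a caller drives a `MoECommMethod`. It mirrors the `fused_experts` helper and `AscendFusedMoE.forward_impl` later in this diff; `run_experts_mlp` is a placeholder for the grouped expert MLP, not an API from this PR.

```python
def moe_forward(comm, hidden_states, router_logits, topk_ids, topk_weights,
                expert_map, num_experts, run_experts_mlp, reduce_results=True):
    # 1. prepare: pad/split/gather activations before routing (DP/TP aware).
    hidden_states, router_logits = comm.prepare(hidden_states, router_logits)
    # 2. permute: group tokens by local expert; returns per-expert counts.
    permuted, expert_tokens, group_list_type = comm.permute(
        hidden_states, topk_ids, topk_weights, expert_map, num_experts)
    # 3. expert MLP over the permuted tokens (placeholder callable here).
    mlp_output = run_experts_mlp(permuted, expert_tokens, group_list_type)
    # 4. unpermute: scatter expert outputs back into hidden_states in place.
    comm.unpermute(mlp_output, hidden_states)
    # 5. finalize: undo padding and apply the cross-rank reduction.
    return comm.finalize(hidden_states, reduce_results)
```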
@@ -82,7 +96,18 @@ class MoECommMethod(ABC):

 class DummyCommImpl(MoECommMethod):

-    def _pre_process(
+    def prepare(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Dummy prepare method that does nothing."""
+        return hidden_states, router_logits
+
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
+        """Dummy finalize method that does nothing."""
+        return hidden_states
+
+    def permute(
         self,
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
@@ -90,92 +115,20 @@ class DummyCommImpl(MoECommMethod):
         expert_map: torch.Tensor,
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
-        """Dummy implementation, see moe_comm_pre_process_fake for details."""
-        return moe_comm_pre_process_fake(hidden_states, topk_ids, topk_weights,
-                                         expert_map, num_experts)
-
-    def _post_process(self, mlp_output: torch.Tensor,
-                      hidden_states: torch.Tensor) -> None:
-        """Dummy implementation that does nothing."""
-        pass
-
-
-class NativeAllGatherCommImpl(MoECommMethod):
-    """This implementation should be compatible with all scenarios.
-
-    Note that this implementation purely consists of native PyTorch ops
-    and does not use any NPU-specific ops. So the performance may not be optimal.
-    But it is a good fallback for scenarios where NPU-specific ops are not available.
-    """
-
-    def _pre_process(
-        self,
-        hidden_states: torch.Tensor,
-        topk_ids: torch.Tensor,
-        topk_weights: torch.Tensor,
-        expert_map: torch.Tensor,
-        num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
-        num_tokens = hidden_states.shape[0]
-
-        # Generate token indices and flatten
-        token_indices = torch.arange(num_tokens,
-                                     device=self.device,
-                                     dtype=torch.int64)
-        token_indices = (token_indices.unsqueeze(1).expand(
-            -1, self.top_k_num).reshape(-1))
-
-        # Flatten token-to-expert mappings and map to local experts
-        weights_flat = topk_weights.view(-1)
-        experts_flat = topk_ids.view(-1)
-        local_experts_flat = (expert_map[experts_flat]
-                              if expert_map is not None else experts_flat)
-
-        # Filter valid token-expert pairs
-        mask = local_experts_flat != -1
-        # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
-        # So we need to filter out invalid tokens by zeroing their weights.
-        # This is a workaround and should be removed after the issue is fixed
-        filtered_weights = torch.where(mask, weights_flat,
-                                       torch.zeros_like(weights_flat)).to(
-                                           self.dtype)
-        filtered_experts = torch.where(
-            mask,
-            local_experts_flat,
-            torch.full_like(local_experts_flat, num_experts),
-        ).to(topk_ids.dtype)
-
-        # Sort by local expert IDs
-        sort_indices = torch.argsort(filtered_experts.view(torch.float32))
-        self.sorted_token_indices = token_indices[sort_indices]
-        self.sorted_weights = filtered_weights[sort_indices]
-
-        # Compute token counts with minlength of num_experts
-        # This is equivalent to but faster than:
-        # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
-        token_counts = torch.zeros(num_experts + 1,
-                                   device=self.device,
-                                   dtype=torch.int64)
-        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
-        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
-        expert_tokens = token_counts[:num_experts]
-
-        # Rearrange hidden_states
-        permuted_hidden_states = hidden_states[self.sorted_token_indices]
-
-        group_list_type = 1  # `count` mode
-
+        """Dummy implementation, make sure the output shapes are correct."""
+        top_k_num = topk_ids.shape[1]
+        permuted_hidden_states = hidden_states.repeat_interleave(top_k_num,
+                                                                 dim=0)
+        expert_tokens = torch.zeros((num_experts, ),
+                                    dtype=torch.int64,
+                                    device=hidden_states.device)
+        group_list_type = 0
         return permuted_hidden_states, expert_tokens, group_list_type

-    def _post_process(self, mlp_output: torch.Tensor,
-                      hidden_states: torch.Tensor) -> None:
-        mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)
-
-        final_hidden_states = torch.zeros_like(hidden_states)
-        final_hidden_states.index_add_(0, self.sorted_token_indices,
-                                       mlp_output)
-
-        hidden_states[:] = final_hidden_states
+    def unpermute(self, mlp_output: torch.Tensor,
+                  hidden_states: torch.Tensor) -> None:
+        """Dummy implementation that does nothing."""
+        pass


 class AllGatherCommImpl(MoECommMethod):
@@ -197,7 +150,46 @@ class AllGatherCommImpl(MoECommMethod):
    This is a workaround and should be removed after the issue is fixed.
    """

-    def _pre_process(
+    def prepare(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """When DP size > 1, pad the hidden states and router logits for communication."""
+        if self.moe_config.dp_size > 1:
+            forward_context = get_forward_context()
+            max_tokens_across_dp = forward_context.max_tokens_across_dp
+
+            self.num_tokens = hidden_states.shape[0]
+            pad_size = max_tokens_across_dp - self.num_tokens
+            if pad_size > 0:
+                hidden_states = nn.functional.pad(hidden_states,
+                                                  (0, 0, 0, pad_size))
+                router_logits = nn.functional.pad(router_logits,
+                                                  (0, 0, 0, pad_size))
+
+            hidden_states = self.moe_config.dp_group.all_gather(
+                hidden_states, 0)
+            router_logits = self.moe_config.dp_group.all_gather(
+                router_logits, 0)
+
+        return hidden_states, router_logits
+
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
+        """When DP size > 1, reduce-scatter the hidden states to get the final output.
+
+        When TP size > 1, all-reduce the hidden states to get the final output.
+        """
+        if self.moe_config.dp_size > 1:
+            hidden_states = data_parallel_reduce_scatter(hidden_states, dim=0)
+            hidden_states = hidden_states[:self.num_tokens]
+
+        if reduce_results and (self.moe_config.tp_size > 1
+                               or self.moe_config.ep_size > 1):
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+
+        return hidden_states
+
+    def permute(
         self,
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
@@ -220,15 +212,15 @@ class AllGatherCommImpl(MoECommMethod):
         # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph
         self.topk_weights = torch.where(mask, topk_weights, 0.0)

-        first_expert_idx = get_ep_group().rank_in_group * num_experts
+        first_expert_idx = self.moe_config.ep_rank * num_experts
         last_expert_idx = first_expert_idx + num_experts

         permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
             torch_npu.npu_moe_init_routing_v2(
                 hidden_states,
                 topk_ids,
-                active_num=num_tokens * self.top_k_num,
-                expert_num=self.global_num_experts,
+                active_num=num_tokens * self.moe_config.experts_per_token,
+                expert_num=self.moe_config.num_experts,
                 expert_tokens_num_type=1,  # Only support `count` mode now
                 expert_tokens_num_flag=True,  # Output `expert_tokens`
                 active_expert_range=[first_expert_idx, last_expert_idx],
@@ -241,14 +233,92 @@ class AllGatherCommImpl(MoECommMethod):

         return permuted_hidden_states, expert_tokens, group_list_type

-    def _post_process(self, mlp_output: torch.Tensor,
-                      hidden_states: torch.Tensor) -> None:
+    def unpermute(self, mlp_output: torch.Tensor,
+                  hidden_states: torch.Tensor) -> None:
         hidden_states[:] = torch_npu.npu_moe_token_unpermute(
             permuted_tokens=mlp_output,
             sorted_indices=self.expanded_row_idx,
             probs=self.topk_weights)


+class NativeAllGatherCommImpl(AllGatherCommImpl):
+    """This implementation should be compatible with all scenarios.
+
+    Note that this implementation purely consists of native PyTorch ops
+    and does not use any NPU-specific ops. So the performance may not be optimal.
+    But it is a good fallback for scenarios where NPU-specific ops are not available.
+    """
+
+    def permute(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        expert_map: torch.Tensor,
+        num_experts: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        num_tokens = hidden_states.shape[0]
+
+        # Generate token indices and flatten
+        token_indices = torch.arange(num_tokens,
+                                     device=hidden_states.device,
+                                     dtype=torch.int64)
+        token_indices = (token_indices.unsqueeze(1).expand(
+            -1, self.moe_config.experts_per_token).reshape(-1))
+
+        # Flatten token-to-expert mappings and map to local experts
+        weights_flat = topk_weights.view(-1)
+        experts_flat = topk_ids.view(-1)
+        local_experts_flat = (expert_map[experts_flat]
+                              if expert_map is not None else experts_flat)
+
+        # Filter valid token-expert pairs
+        mask = local_experts_flat != -1
+        # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
+        # So we need to filter out invalid tokens by zeroing their weights.
+        # This is a workaround and should be removed after the issue is fixed
+        filtered_weights = torch.where(mask, weights_flat,
+                                       torch.zeros_like(weights_flat)).to(
+                                           topk_weights.dtype)
+        filtered_experts = torch.where(
+            mask,
+            local_experts_flat,
+            torch.full_like(local_experts_flat, num_experts),
+        ).to(topk_ids.dtype)
+
+        # Sort by local expert IDs
+        sort_indices = torch.argsort(filtered_experts.view(torch.float32))
+        self.sorted_token_indices = token_indices[sort_indices]
+        self.sorted_weights = filtered_weights[sort_indices]
+
+        # Compute token counts with minlength of num_experts
+        # This is equivalent to but faster than:
+        # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
+        token_counts = torch.zeros(num_experts + 1,
+                                   device=hidden_states.device,
+                                   dtype=torch.int64)
+        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
+        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
+        expert_tokens = token_counts[:num_experts]
+
+        # Rearrange hidden_states
+        permuted_hidden_states = hidden_states[self.sorted_token_indices]
+
+        group_list_type = 1  # `count` mode
+
+        return permuted_hidden_states, expert_tokens, group_list_type
+
+    def unpermute(self, mlp_output: torch.Tensor,
+                  hidden_states: torch.Tensor) -> None:
+        mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)
+
+        final_hidden_states = torch.zeros_like(hidden_states)
+        final_hidden_states.index_add_(0, self.sorted_token_indices,
+                                       mlp_output)
+
+        hidden_states[:] = final_hidden_states
+
+
 class MC2CommImpl(MoECommMethod):
     """This implementation is for the scenarios listed below:
     1. `enable_expert_parallel=True`.
@@ -259,40 +329,83 @@ class MC2CommImpl(MoECommMethod):
     Communication and Computation parallelism on Ascend devices.
     """

-    def __init__(
-        self,
-        device: torch.device,
-        dtype: torch.dtype,
-        hf_config: PretrainedConfig,
-    ):
-        super().__init__(device, dtype, hf_config)
+    def __init__(self, moe_config: Optional[FusedMoEConfig]):
+        super().__init__(moe_config)

-        # Shared communication configurations
-        ep_group = get_mc2_group()
-        self.ep_rank_id = ep_group.rank_in_group
-        self.ep_world_size = ep_group.world_size
-        self.tp_world_size = get_tp_group().world_size
-
-        device_group = ep_group.device_group
-        local_rank = torch.distributed.get_rank(group=device_group)
-        backend = device_group._get_backend(torch.device("npu"))
-        self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
+        # NOTE: We do not need to use mc2_group's rank and world size
+        # because ep_group and mc2_group basically have the same init params.
+        # We only init another group because of the restriction of MC2:
+        # "No other groups can be used in the same process as the MC2 group."
+        self.mc2_comm_name = get_mc2_group().device_group._get_backend(
+            torch.device("npu")).get_hccl_comm_name(self.moe_config.ep_rank)

         # Feature flags
         self.enable_dispatch_v2 = hasattr(torch_npu,
                                           "npu_moe_distribute_dispatch_v2")
         self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
-        self.need_extra_args = self.is_ascend_a3  # or is_torchair
+        self.need_extra_args = self.is_ascend_a3
+        self._restore_tp_across_dp()

-        # Intermediate tensors to be passed from pre_process to post_process
-        self.topk_ids = None
-        self.topk_weights = None
-        self.mc2_mask = None
-        self.assist_info_for_combine = None
-        self.ep_recv_counts = None
-        self.tp_recv_counts = None
+    def _restore_tp_across_dp(self):
+        # NOTE: Since vLLM flatten tp across dp, we need to restore the original
+        # tp_size and tp_rank.
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()

-    def _pre_process(
+    def prepare(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """The target_pad_length is calculated in forward_context, here we pad the
+        hidden states and router logits. And if TP size > 1, we also need to split
+        the tensors accordingly.
+        """
+        self.num_tokens, _ = hidden_states.shape
+        forward_context = get_forward_context()
+        self.mc2_mask = forward_context.mc2_mask
+        target_pad_length = forward_context.padded_num_tokens
+        pad_size = target_pad_length - self.num_tokens
+
+        if pad_size > 0:
+            hidden_states = nn.functional.pad(hidden_states,
+                                              (0, 0, 0, pad_size))
+            router_logits = nn.functional.pad(router_logits,
+                                              (0, 0, 0, pad_size))
+
+        if self.tp_size > 1:
+            split_hidden_states = torch.tensor_split(hidden_states,
+                                                     self.tp_size,
+                                                     dim=0)
+            split_router_logits = torch.tensor_split(router_logits,
+                                                     self.tp_size,
+                                                     dim=0)
+            split_mc2_mask = torch.tensor_split(self.mc2_mask,
+                                                self.tp_size,
+                                                dim=0)
+            self.split_hidden_states = split_hidden_states
+
+            hidden_states = split_hidden_states[self.tp_rank]
+            router_logits = split_router_logits[self.tp_rank]
+            self.mc2_mask = split_mc2_mask[self.tp_rank]
+
+        return hidden_states, router_logits
+
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
+        """If TP size > 1, all-gather the hidden states to get the final output.
+
+        Also, unpad the hidden states if needed.
+        """
+        if self.tp_size > 1:
+            dist.all_gather(list(self.split_hidden_states), hidden_states,
+                            self.moe_config.tp_group.device_group)
+            hidden_states = torch.cat(self.split_hidden_states, dim=0)
+
+        if self.num_tokens < hidden_states.shape[0]:
+            hidden_states = hidden_states[:self.num_tokens]
+
+        return hidden_states
+
+    def permute(
         self,
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
@@ -303,25 +416,24 @@ class MC2CommImpl(MoECommMethod):
         # Store tensors needed for post_process
         self.topk_ids = topk_ids
         self.topk_weights = topk_weights.to(torch.float32)
-        self.mc2_mask = get_forward_context().mc2_mask

         dispatch_kwargs = {
             "x": hidden_states,
             "expert_ids": self.topk_ids,
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
-            "moe_expert_num": self.global_num_experts,
+            "moe_expert_num": self.moe_config.num_experts,
             "global_bs": 0,
             "scales": None,
             "quant_mode": 0,
-            "group_ep": self.moe_all_to_all_group_name,
-            "ep_world_size": self.ep_world_size,
-            "ep_rank_id": self.ep_rank_id,
+            "group_ep": self.mc2_comm_name,
+            "ep_world_size": self.moe_config.ep_size,
+            "ep_rank_id": self.moe_config.ep_rank,
         }

         if self.need_extra_args:
             dispatch_kwargs.update({
-                "group_tp": self.moe_all_to_all_group_name,
+                "group_tp": self.mc2_comm_name,
                 "tp_world_size": 1,
                 "tp_rank_id": 0,
             })
@@ -345,20 +457,20 @@ class MC2CommImpl(MoECommMethod):

         return permuted_hidden_states, expert_tokens, group_list_type

-    def _post_process(self, mlp_output: torch.Tensor,
-                      hidden_states: torch.Tensor) -> None:
+    def unpermute(self, mlp_output: torch.Tensor,
+                  hidden_states: torch.Tensor) -> None:
         combine_kwargs = {
             "expand_x": mlp_output,
             "expert_ids": self.topk_ids,
             "expert_scales": self.topk_weights,
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
-            "moe_expert_num": self.global_num_experts,
+            "moe_expert_num": self.moe_config.num_experts,
             "global_bs": 0,
             "ep_send_counts": self.ep_recv_counts,
-            "group_ep": self.moe_all_to_all_group_name,
-            "ep_world_size": self.ep_world_size,
-            "ep_rank_id": self.ep_rank_id,
+            "group_ep": self.mc2_comm_name,
+            "ep_world_size": self.moe_config.ep_size,
+            "ep_rank_id": self.moe_config.ep_rank,
         }

         if self.enable_dispatch_v2:
@@ -370,7 +482,7 @@ class MC2CommImpl(MoECommMethod):
         if self.need_extra_args:
             combine_kwargs.update({
                 "tp_send_counts": self.tp_recv_counts,
-                "group_tp": self.moe_all_to_all_group_name,
+                "group_tp": self.mc2_comm_name,
                 "tp_world_size": 1,
                 "tp_rank_id": 0,
             })
@@ -382,68 +494,3 @@ class MC2CommImpl(MoECommMethod):
         combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine

         hidden_states[:] = combine(**combine_kwargs)
-
-
-def moe_comm_pre_process(
-    hidden_states: torch.Tensor,
-    topk_ids: torch.Tensor,
-    topk_weights: torch.Tensor,
-    expert_map: torch.Tensor,
-    num_experts: int,
-) -> tuple[torch.Tensor, torch.Tensor, int]:
-    """This function is a wrapper for the pre_process method of the
-    MoECommMethod instance stored in the ForwardContext. So it can be
-    used as a custom op in the vllm framework.
-    """
-    forward_context: ForwardContext = get_forward_context()
-    self = forward_context.moe_comm_method
-    return self._pre_process(hidden_states, topk_ids, topk_weights, expert_map,
-                             num_experts)
-
-
-def moe_comm_pre_process_fake(
-    hidden_states: torch.Tensor,
-    topk_ids: torch.Tensor,
-    topk_weights: torch.Tensor,
-    expert_map: torch.Tensor,
-    num_experts: int,
-) -> tuple[torch.Tensor, torch.Tensor, int]:
-    """This is a fake implementation of the pre_process method.
-    torch.compile will use this implementation to generate FX graph.
-    """
-    top_k_num = topk_ids.shape[1]
-    permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, dim=0)
-    expert_tokens = torch.zeros((num_experts, ),
-                                dtype=torch.int64,
-                                device=hidden_states.device)
-    group_list_type = 0
-    return permuted_hidden_states, expert_tokens, group_list_type
-
-
-def moe_comm_post_process(mlp_output: torch.Tensor,
-                          hidden_states: torch.Tensor) -> None:
-    """This function is a wrapper for the post_process method of the
-    MoECommMethod instance stored in the ForwardContext. So it can be
-    used as a custom op in the vllm framework.
-    """
-    forward_context: ForwardContext = get_forward_context()
-    self = forward_context.moe_comm_method
-    self._post_process(mlp_output, hidden_states)
-    return
-
-
-direct_register_custom_op(
-    op_name="moe_comm_pre_process",
-    op_func=moe_comm_pre_process,
-    mutates_args=[],
-    fake_impl=moe_comm_pre_process_fake,
-    dispatch_key="PrivateUse1",
-)
-
-direct_register_custom_op(
-    op_name="moe_comm_post_process",
-    op_func=moe_comm_post_process,
-    mutates_args=["hidden_states"],
-    fake_impl=lambda x, y: None,  # No-op for fake implementation
-    dispatch_key="PrivateUse1",
-)
@@ -497,6 +497,10 @@ class PanguProMoESparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         global _ROUTER_SCALE
         _ROUTER_SCALE = self.router_scale

+        # TODO(angazenn): Does not support MC2 currently
+        get_forward_context().moe_comm_method_name = "allgathercommimpl"
+
         if not use_h2p():
             final_hidden_states = self.experts.forward_impl(
                 hidden_states=hidden_states, router_logits=router_logits)
@@ -15,22 +15,84 @@
 # limitations under the License.
 #

-from typing import Callable, Optional
+from typing import Any, Callable, Optional

 import torch
 from vllm.config import CompilationLevel, get_current_vllm_config
+from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.fused_moe.layer import \
-    UnquantizedFusedMoEMethod
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, UnquantizedFusedMoEMethod)

 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ops.fused_moe import fused_experts_moge, unified_fused_experts
+from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
+                                                     DummyCommImpl,
+                                                     MC2CommImpl,
+                                                     MoECommMethod)
+from vllm_ascend.distributed.parallel_state import get_mc2_group
+from vllm_ascend.ops.fused_moe import apply_mlp, fused_experts_moge
 from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.utils import is_310p

 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__


+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_int8_w8a8: bool = False,
+    use_int4_w4a8: bool = False,
+    global_num_experts: Optional[int] = None,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    w1_scale_bias: torch.Tensor = None,
+    w2_scale_bias: torch.Tensor = None,
+    moe_comm_method: Optional[MoECommMethod] = None,
+    # For TorchAir graph
+    is_torchair: bool = False,
+    # For Cube/Vector parallel
+    shared_experts: Optional[Any] = None,
+    quantized_x_for_share: Optional[Any] = None,
+    dynamic_scale_for_share: Optional[Any] = None,
+    # For load balance
+    log2phy: torch.Tensor = None,
+    global_redundant_expert_num: int = 0,
+) -> torch.Tensor:
+    # Check constraints
+    assert hidden_states.shape[1] == w1.shape[2], (
+        f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
+
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
+    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
+    assert moe_comm_method is not None, "Missing communication context"
+
+    num_experts = w1.shape[0]
+
+    permuted_hidden_states, expert_tokens, group_list_type = moe_comm_method.permute(
+        hidden_states, topk_ids, topk_weights, expert_map, num_experts)
+    mlp_output = apply_mlp(
+        permuted_hidden_states,
+        w1,
+        w2,
+        expert_tokens,
+        group_list_type=group_list_type,
+    )
+    moe_comm_method.unpermute(mlp_output, hidden_states)
+
+    return hidden_states
+
+
 def unquantized_fused_moe_init_func(self, *args, **kwargs):
     original_unquantized_fused_moe_init_func(self, *args, **kwargs)
     vllm_config = get_current_vllm_config()
@@ -97,7 +159,7 @@ def forward_oot(

     moe_comm_method = get_forward_context().moe_comm_method

-    return unified_fused_experts(
+    return fused_experts(
         hidden_states=x,
         w1=layer.w13_weight,
         w2=layer.w2_weight,
@@ -109,5 +171,112 @@ def forward_oot(
     )


+class AscendFusedMoE(FusedMoE):
+
+    def __init__(
+        self,
+        num_experts,
+        top_k,
+        hidden_size,
+        intermediate_size,
+        params_dtype=None,
+        reduce_results=False,
+        renormalize=True,
+        use_grouped_topk=False,
+        num_expert_group=None,
+        topk_group=None,
+        quant_config=None,
+        tp_size=None,
+        ep_size=None,
+        dp_size=None,
+        prefix="",
+        custom_routing_function=None,
+        scoring_func="softmax",
+        e_score_correction_bias=None,
+        apply_router_weight_on_input=False,
+        activation="silu",
+        enable_eplb=False,
+        num_redundant_experts=0,
+        has_bias=False,
+    ):
+        super().__init__(
+            num_experts,
+            top_k,
+            hidden_size,
+            intermediate_size,
+            params_dtype,
+            reduce_results,
+            renormalize,
+            use_grouped_topk,
+            num_expert_group,
+            topk_group,
+            quant_config,
+            tp_size,
+            ep_size,
+            dp_size,
+            prefix,
+            custom_routing_function,
+            scoring_func,
+            e_score_correction_bias,
+            apply_router_weight_on_input,
+            activation,
+            enable_eplb,
+            num_redundant_experts,
+            has_bias,
+        )
+
+        self.moe_config.tp_group = get_tp_group()
+        self.moe_config.dp_group = get_dp_group()
+        self.moe_config.ep_group = get_ep_group()
+        self.moe_config.mc2_group = get_mc2_group()
+
+        for method in {AllGatherCommImpl, DummyCommImpl, MC2CommImpl}:
+            setattr(
+                self, method.__name__.lower(),
+                method(moe_config=self.moe_config))  # type: ignore[abstract]
+
+    def forward_impl(self, hidden_states: torch.Tensor,
+                     router_logits: torch.Tensor):
+        assert self.quant_method is not None
+
+        forward_context = get_forward_context()
+        moe_comm_method_name = forward_context.moe_comm_method_name
+        if not self.moe_config.use_ep and moe_comm_method_name != "dummycommimpl":
+            moe_comm_method_name = "allgathercommimpl"
+        forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
+
+        hidden_states, router_logits = forward_context.moe_comm_method.prepare(
+            hidden_states=hidden_states, router_logits=router_logits)
+
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            global_num_experts=self.global_num_experts,
+            expert_map=self.expert_map,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            custom_routing_function=self.custom_routing_function,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias,
+            activation=self.activation,
+            apply_router_weight_on_input=self.apply_router_weight_on_input,
+            enable_eplb=self.enable_eplb,
+            expert_load_view=self.expert_load_view,
+            logical_to_physical_map=self.logical_to_physical_map,
+            logical_replica_count=self.logical_replica_count,
+        )
+
+        final_hidden_states = forward_context.moe_comm_method.finalize(
+            hidden_states=final_hidden_states,
+            reduce_results=self.reduce_results)
+
+        return final_hidden_states
+
+
 UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.forward_oot = forward_oot
@@ -43,7 +43,6 @@ from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.communication_op import \
     data_parallel_reduce_scatter
-from vllm_ascend.distributed.moe_comm_method import MoECommMethod
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.layers.experts_selector import select_experts
@@ -58,60 +57,6 @@ from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
|||||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||||
|
|
||||||
|
|
||||||
def unified_fused_experts(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
use_int8_w8a8: bool = False,
|
|
||||||
use_int4_w4a8: bool = False,
|
|
||||||
global_num_experts: Optional[int] = None,
|
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
|
||||||
w1_scale_bias: torch.Tensor = None,
|
|
||||||
w2_scale_bias: torch.Tensor = None,
|
|
||||||
moe_comm_method: Optional[MoECommMethod] = None,
|
|
||||||
# For Cube/Vector parallel
|
|
||||||
shared_experts: Optional[Any] = None,
|
|
||||||
quantized_x_for_share: Optional[Any] = None,
|
|
||||||
dynamic_scale_for_share: Optional[Any] = None,
|
|
||||||
# For load balance
|
|
||||||
log2phy: torch.Tensor = None,
|
|
||||||
global_redundant_expert_num: int = 0,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# Check constraints
|
|
||||||
assert hidden_states.shape[1] == w1.shape[2], (
|
|
||||||
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
|
|
||||||
|
|
||||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
|
||||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
|
||||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
|
||||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
|
||||||
assert hidden_states.dtype in [
|
|
||||||
torch.float32, torch.float16, torch.bfloat16
|
|
||||||
]
|
|
||||||
assert moe_comm_method is not None, "Missing communication context"
|
|
||||||
|
|
||||||
num_experts = w1.shape[0]
|
|
||||||
|
|
||||||
permuted_hidden_states, expert_tokens, group_list_type = torch.ops.vllm.moe_comm_pre_process(
|
|
||||||
hidden_states, topk_ids, topk_weights, expert_map, num_experts)
|
|
||||||
mlp_output = apply_mlp(
|
|
||||||
permuted_hidden_states,
|
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
expert_tokens,
|
|
||||||
group_list_type=group_list_type,
|
|
||||||
)
|
|
||||||
torch.ops.vllm.moe_comm_post_process(mlp_output, hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
||||||
max_row_per_ep_rank: int, num_tokens: int,
|
max_row_per_ep_rank: int, num_tokens: int,
|
||||||
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|||||||
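The removed helper wrapped a permute, grouped per-expert MLP, then un-permute pipeline behind the `torch.ops.vllm.moe_comm_pre_process` / `moe_comm_post_process` custom ops. A rough pure-PyTorch sketch of that flow under simplifying assumptions (top-1 routing, a toy square per-expert weight, no Ascend ops):

```python
import torch

num_tokens, hidden, num_experts = 8, 4, 3
x = torch.randn(num_tokens, hidden)
topk_ids = torch.randint(0, num_experts, (num_tokens,))  # top-1 for brevity

# "pre-process": sort tokens by expert so each expert sees a contiguous block
order = torch.argsort(topk_ids)
permuted = x[order]
expert_tokens = torch.bincount(topk_ids, minlength=num_experts)  # group sizes

# grouped per-expert "MLP" (a single matmul per expert, standing in for apply_mlp)
w = torch.randn(num_experts, hidden, hidden)
out_blocks = []
start = 0
for e in range(num_experts):
    n = int(expert_tokens[e])
    out_blocks.append(permuted[start:start + n] @ w[e])
    start += n
mlp_out = torch.cat(out_blocks)

# "post-process": scatter results back to the original token order
y = torch.empty_like(mlp_out)
y[order] = mlp_out
```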
@@ -509,6 +509,9 @@ def register_ascend_customop():
     from vllm_ascend.ops.layernorm import AscendRMSNorm
     CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")
 
+    from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
+    CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE")
+
     # NOTE: Keep this at last to ensure all custom actions are registered
     _ASCEND_CUSTOMOP_IS_REIGISTERED = True
 
@@ -24,7 +24,7 @@ import os
 import time
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
 
 import numpy as np
 import numpy.typing as npt
@@ -85,9 +85,6 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
 from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
-from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
-                                                     DummyCommImpl,
-                                                     MoECommMethod)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -368,13 +365,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
             self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
 
+        self.mc2_tokens_capacity = 512 * self.parallel_config.tensor_parallel_size
         self.reserved_mc2_mask = torch.zeros(
-            512,
+            self.mc2_tokens_capacity,
             dtype=torch.bool,
             device=self.device,
         )
 
-        self.moe_comm_method = AllGatherCommImpl
+        self.moe_comm_method = "mc2"
+        self.fallback_moe_comm_method = "allgather"
+        self.dummy_moe_comm_method = "dummy"
 
     def _use_aclgraph(self) -> bool:
         return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
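The reserved mask is sized once for the largest batch the MC2 path will accept. A small sketch of the sizing rule (the tensor-parallel size is illustrative, and the mask is allocated on CPU here rather than on the NPU):

```python
import torch

tensor_parallel_size = 4  # illustrative; the runner reads this from its parallel config
mc2_tokens_capacity = 512 * tensor_parallel_size  # 2048 token slots

reserved_mc2_mask = torch.zeros(mc2_tokens_capacity, dtype=torch.bool)
print(reserved_mc2_mask.shape)  # torch.Size([2048])
```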
@@ -1622,6 +1622,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
          intermediate_tensors) = (self._prepare_inputs(
              scheduler_output, intermediate_tensors))
 
+        moe_comm_method = (self.moe_comm_method
+                           if num_input_tokens <= self.mc2_tokens_capacity else
+                           self.fallback_moe_comm_method)
+
         # Run forward pass
         with ProfileExecuteDuration().capture_async("forward"):
             with set_ascend_forward_context(
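In effect the runner prefers MC2 and drops back to all-gather once the batch outgrows the reserved MC2 capacity. A self-contained sketch of that rule (names mirror the attributes above; the capacity value is illustrative):

```python
def select_moe_comm_method(num_input_tokens: int,
                           mc2_tokens_capacity: int,
                           preferred: str = "mc2",
                           fallback: str = "allgather") -> str:
    # Mirrors the ternary above: MC2 only while the batch fits the
    # pre-sized reserved_mc2_mask, otherwise fall back to all-gather.
    return preferred if num_input_tokens <= mc2_tokens_capacity else fallback


capacity = 512 * 4  # e.g. tensor_parallel_size = 4
assert select_moe_comm_method(256, capacity) == "mc2"
assert select_moe_comm_method(4096, capacity) == "allgather"
```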
@@ -1631,8 +1635,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     num_tokens_across_dp=num_tokens_across_dp,
                     with_prefill=self.with_prefill,
                     reserved_mc2_mask=self.reserved_mc2_mask,
-                    moe_comm_method=self.moe_comm_method(
-                        self.device, self.dtype, self.model_config.hf_config),
+                    moe_comm_method=moe_comm_method,
                     num_actual_tokens=scheduler_output.
                     total_num_scheduled_tokens):
                 self.maybe_setup_kv_connector(scheduler_output)
@@ -1938,7 +1941,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         num_tokens: int,
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
-        moe_comm_method: Type[MoECommMethod] = DummyCommImpl,
+        moe_comm_method: str = "dummy",
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         force_attention: bool = False,
         uniform_decode: bool = False,
@@ -2061,8 +2064,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 with_prefill=with_prefill,
                 in_profile_run=self.in_profile_run,
                 reserved_mc2_mask=self.reserved_mc2_mask,
-                moe_comm_method=moe_comm_method(
-                    self.device, self.dtype, self.model_config.hf_config),
+                moe_comm_method=moe_comm_method,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor):