refactor allgather/mc2-related fused_experts (#2369)
### What this PR does / why we need it?
refactor allgather/mc2-related fused_experts
- vLLM version: v0.10.0
- vLLM main:
de7b67a023
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -15,12 +15,17 @@
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
|
||||
AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
|
||||
from vllm_ascend.utils import adapt_patch # noqa E402
|
||||
|
||||
|
||||
@@ -63,3 +68,289 @@ class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase):
|
||||
assert dispatcher.ep_rank == 0
|
||||
assert dispatcher.ep_size == 2
|
||||
assert dispatcher.overlap_stream is not None
|
||||
|
||||
|
||||
class TestTokenDispatcherWithMC2(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mc2_group = mock.MagicMock()
|
||||
self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123"
|
||||
self.mc2_group.rank_in_group = 0
|
||||
self.mc2_group.world_size = 8
|
||||
self.mc2_group_patch = mock.patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
|
||||
return_value=self.mc2_group)
|
||||
self.mc2_group_patch.start()
|
||||
|
||||
self.rank_group_patch = mock.patch("torch.distributed.get_rank",
|
||||
return_value=0)
|
||||
self.rank_group_patch.start()
|
||||
|
||||
# Mock get_forward_context().mc2_mask
|
||||
self.forward_context = mock.MagicMock()
|
||||
self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
|
||||
self.forward_context_patch = mock.patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context",
|
||||
return_value=self.forward_context)
|
||||
self.forward_context_patch.start()
|
||||
|
||||
# Mock get_ascend_soc_version()
|
||||
self.ascend_soc_version_patch = mock.patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
|
||||
return_value=AscendSocVersion.A3)
|
||||
self.ascend_soc_version_patch.start()
|
||||
|
||||
# Mock get_ascend_config()
|
||||
self.ascend_config = mock.MagicMock()
|
||||
self.ascend_config.torchair_graph_config.enabled = False
|
||||
self.ascend_config_patch = mock.patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config",
|
||||
return_value=self.ascend_config)
|
||||
self.ascend_config_patch.start()
|
||||
|
||||
kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
|
||||
self.dispatcher = TokenDispatcherWithMC2(**kwargs)
|
||||
|
||||
def tearDown(self):
|
||||
self.mc2_group_patch.stop()
|
||||
self.forward_context_patch.stop()
|
||||
self.ascend_soc_version_patch.stop()
|
||||
self.ascend_config_patch.stop()
|
||||
|
||||
def test_init(self):
|
||||
# self.assertEqual(self.dispatcher.moe_all_to_all_group_name, "hccl_123")
|
||||
self.assertEqual(self.dispatcher.ep_rank_id, 0)
|
||||
self.assertEqual(self.dispatcher.ep_world_size, 8)
|
||||
self.assertFalse(self.dispatcher.torchair_graph_enabled)
|
||||
self.assertFalse(self.dispatcher.with_quant)
|
||||
self.assertTrue(self.dispatcher.enable_dispatch_v2)
|
||||
self.assertTrue(self.dispatcher.need_extra_args)
|
||||
self.assertTrue(self.dispatcher.a3_need_extra_args)
|
||||
|
||||
def test_get_permute_mc2_kwargs_without_quant(self):
|
||||
hidden_states = torch.randn(10, 128)
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
topk_weights = torch.randn(10, 1)
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
|
||||
kwargs = self.dispatcher.get_permute_mc2_kwargs(
|
||||
hidden_states, topk_weights, topk_ids, expert_map)
|
||||
self.assertIn("x", kwargs)
|
||||
self.assertIn("expert_ids", kwargs)
|
||||
self.assertEqual(kwargs["moe_expert_num"], 8)
|
||||
|
||||
def test_token_permutation_dispatch(self):
|
||||
hidden_states = torch.randn(10, 128)
|
||||
topk_weights = torch.randn(10, 1)
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
|
||||
with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
||||
return_value=(torch.randn(10, 128), ) *
|
||||
5) as mock_dispatch:
|
||||
output = self.dispatcher.token_permutation(hidden_states,
|
||||
topk_weights, topk_ids,
|
||||
expert_map)
|
||||
mock_dispatch.assert_called_once()
|
||||
self.assertEqual(output[0], 1) # group_list_type == 1
|
||||
|
||||
def test_token_permutation_with_shared_experts_and_quant(self):
|
||||
self.shared_experts = mock.MagicMock()
|
||||
self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
|
||||
torch.tensor(1.0))
|
||||
self.shared_experts.act_fn.return_value = torch.randn(10, 128)
|
||||
self.dispatcher.with_quant = False
|
||||
self.dispatcher.shared_act = torch.randn(10, 128)
|
||||
self.dispatcher.swiglu_out_scale = torch.tensor(1.0)
|
||||
self.hidden_states = torch.randn(10, 128)
|
||||
self.topk_weights = torch.randn(10, 1)
|
||||
|
||||
with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
||||
return_value=(torch.randn(10, 128), ) * 5):
|
||||
with mock.patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
|
||||
autospec=True):
|
||||
with mock.patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
|
||||
autospec=True) as mock_wait:
|
||||
self.dispatcher.token_permutation(
|
||||
self.hidden_states,
|
||||
self.topk_weights,
|
||||
torch.randint(0, 8, (10, 1)),
|
||||
torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
|
||||
shared_experts=self.shared_experts)
|
||||
mock_wait.assert_any_call(self.hidden_states,
|
||||
self.topk_weights)
|
||||
|
||||
def test_get_unpermute_mc_kwargs_with_quant(self):
|
||||
self.dispatcher.with_quant = True
|
||||
hidden_states = torch.randn(10, 128)
|
||||
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.need_extra_args = True
|
||||
self.dispatcher.enable_dispatch_v2 = True
|
||||
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
||||
|
||||
kwargs = self.dispatcher.get_unpermute_mc_kwargs(hidden_states)
|
||||
self.assertIn("tp_send_counts", kwargs)
|
||||
|
||||
def test_token_unpermutation_with_shared_experts(self):
|
||||
self.dispatcher.shared_experts = mock.MagicMock()
|
||||
self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
|
||||
10, 128), torch.tensor(1.0))
|
||||
self.dispatcher.shared_act = torch.randn(10, 128)
|
||||
self.dispatcher.with_quant = True
|
||||
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.need_extra_args = True
|
||||
self.dispatcher.enable_dispatch_v2 = True
|
||||
self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
||||
self.hidden_states = torch.randn(10, 128)
|
||||
|
||||
with mock.patch("torch_npu.npu_moe_distribute_combine_v2",
|
||||
return_value=torch.randn(10, 128)):
|
||||
with mock.patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
|
||||
autospec=True):
|
||||
with mock.patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
|
||||
autospec=True):
|
||||
self.dispatcher.token_unpermutation(self.hidden_states)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAllGather(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Mock dependencies
|
||||
kwargs = {
|
||||
"apply_router_weight_on_input": False,
|
||||
"top_k": 2,
|
||||
"max_num_tokens": 100,
|
||||
"ep_size": 2,
|
||||
"num_experts": 128,
|
||||
"with_quant": False,
|
||||
}
|
||||
self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
|
||||
|
||||
# Mock NPU functions
|
||||
self.patcher_moe_init_routing = mock.patch(
|
||||
'torch_npu.npu_moe_init_routing')
|
||||
self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
|
||||
self.mock_moe_init_routing.return_value = (
|
||||
torch.randn(6, 128), # sorted_hidden_states
|
||||
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx
|
||||
)
|
||||
|
||||
self.patcher_moe_compute_expert_tokens = mock.patch(
|
||||
'torch_npu.npu_moe_compute_expert_tokens')
|
||||
self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
|
||||
)
|
||||
self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
|
||||
[3, 3]) # expert_tokens
|
||||
|
||||
self.patcher_moe_finalize_routing = mock.patch(
|
||||
'torch_npu.npu_moe_finalize_routing')
|
||||
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
|
||||
)
|
||||
self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
|
||||
|
||||
def tearDown(self):
|
||||
self.patcher_moe_init_routing.stop()
|
||||
self.patcher_moe_compute_expert_tokens.stop()
|
||||
self.patcher_moe_finalize_routing.stop()
|
||||
|
||||
def test_token_permutation_without_expert_map(self):
|
||||
hidden_states = torch.randn(3, 128)
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
|
||||
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
|
||||
hidden_states, topk_weights, topk_ids, None)
|
||||
|
||||
# Verify npu_moe_init_routing is called
|
||||
self.mock_moe_init_routing.assert_called_once()
|
||||
args, kwargs = self.mock_moe_init_routing.call_args
|
||||
|
||||
self.assertEqual(group_list_type, 0)
|
||||
|
||||
def test_token_permutation_with_quant(self):
|
||||
kwargs = {
|
||||
"apply_router_weight_on_input": False,
|
||||
"top_k": 2,
|
||||
"max_num_tokens": 100,
|
||||
"ep_size": 2,
|
||||
"num_experts": 128,
|
||||
"with_quant": True,
|
||||
}
|
||||
self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
|
||||
|
||||
hidden_states = torch.randn(3, 128)
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
|
||||
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_permutation(
|
||||
hidden_states, topk_weights, topk_ids, None)
|
||||
|
||||
# Verify quant mode returns group_list_type=1
|
||||
self.assertEqual(group_list_type, 0)
|
||||
|
||||
def test_token_unpermutation_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.sorted_weights = torch.tensor(
|
||||
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
|
||||
self.dispatcher.original_shape = (3, 128)
|
||||
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
|
||||
hidden_states = torch.randn(6, 128)
|
||||
|
||||
final_hidden_states = self.dispatcher.token_unpermutation(
|
||||
hidden_states)
|
||||
|
||||
# Verify index_add_ is applied correctly
|
||||
self.assertEqual(final_hidden_states.shape, (3, 128))
|
||||
|
||||
def test_token_unpermutation_without_expert_map(self):
|
||||
self.dispatcher.with_quant = False
|
||||
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.sorted_weights = torch.tensor(
|
||||
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
|
||||
self.dispatcher.original_shape = (3, 128)
|
||||
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
|
||||
hidden_states = torch.randn(6, 128)
|
||||
|
||||
final_hidden_states = self.dispatcher.token_unpermutation(
|
||||
hidden_states)
|
||||
|
||||
# Verify npu_moe_finalize_routing is called
|
||||
self.mock_moe_finalize_routing.assert_called_once()
|
||||
args, kwargs = self.mock_moe_finalize_routing.call_args
|
||||
|
||||
self.assertEqual(final_hidden_states.shape, (3, 128))
|
||||
|
||||
def test_token_permutation_with_router_weight(self):
|
||||
self.dispatcher.apply_router_weight_on_input = True
|
||||
hidden_states = torch.randn(3, 128)
|
||||
topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1
|
||||
topk_ids = torch.tensor([[0], [1], [2]])
|
||||
|
||||
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
|
||||
hidden_states, topk_weights, topk_ids, None)
|
||||
self.assertEqual(sorted_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_permutation_invalid_topk_when_router_weight(self):
|
||||
self.dispatcher.apply_router_weight_on_input = True
|
||||
hidden_states = torch.randn(3, 128)
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
self.dispatcher.token_permutation(
|
||||
hidden_states, topk_weights,
|
||||
torch.tensor([[0, 1], [1, 2], [2, 3]]), None)
|
||||
|
||||
@@ -20,17 +20,24 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.distributed.tensor_parallel import (
|
||||
all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp,
|
||||
all_to_all_sp2hp, gather_from_sequence_parallel_region,
|
||||
reduce_scatter_last_dim_to_tensor_parallel_region)
|
||||
from vllm_ascend.ops.comm_utils import async_all_to_all
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
|
||||
class MoEDispatcherConfig:
|
||||
@@ -451,3 +458,505 @@ class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher):
|
||||
self.num_global_tokens_per_local_expert_cpu = None
|
||||
|
||||
return output, None
|
||||
|
||||
|
||||
class MoETokenDispatcher(ABC):
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""
|
||||
Initialize the MoE Token Dispatcher.
|
||||
"""
|
||||
self.top_k = kwargs.get("top_k")
|
||||
self.num_experts = kwargs.get("num_experts")
|
||||
|
||||
@property
|
||||
def ep_group(self):
|
||||
"""Get expert model parallel group."""
|
||||
return get_ep_group().device_group
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return get_ep_group().rank_in_group
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return get_ep_group().world_size
|
||||
|
||||
@abstractmethod
|
||||
def token_permutation(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_gate_up: Optional[Any] = None,
|
||||
shared_dequant_scale: Optional[Any] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
):
|
||||
raise NotImplementedError("Dispatch function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def token_unpermutation(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
raise NotImplementedError("Restore function not implemented.")
|
||||
|
||||
|
||||
class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
device_group = get_mc2_group().device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
|
||||
self.ep_rank_id = get_mc2_group().rank_in_group
|
||||
self.ep_world_size = get_mc2_group().world_size
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.with_quant = kwargs.get("with_quant")
|
||||
self.enable_dispatch_v2 = hasattr(torch_npu,
|
||||
"npu_moe_distribute_dispatch_v2")
|
||||
self.need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
|
||||
or self.torchair_graph_enabled)
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
self.a3_need_extra_args = \
|
||||
get_ascend_soc_version() == AscendSocVersion.A3
|
||||
self.output = None
|
||||
self.dynamic_scale = None
|
||||
self.assist_info_for_combine = None
|
||||
self.ep_recv_counts = None
|
||||
self.shared_act = None
|
||||
self.topk_ids = None
|
||||
self.topk_weights = None
|
||||
self.shared_experts = None
|
||||
|
||||
def get_permute_mc2_kwargs(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
global_redundant_expert_num: int = 0):
|
||||
quant_mode = 0
|
||||
forward_context = get_forward_context()
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
if self.with_quant:
|
||||
if (expert_map is not None):
|
||||
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
else:
|
||||
moe_expert_num = global_redundant_expert_num
|
||||
else:
|
||||
moe_expert_num = len(expert_map)
|
||||
kwargs_mc2 = {
|
||||
"x": hidden_states,
|
||||
"expert_ids": topk_ids,
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
}
|
||||
|
||||
stage1_kwargs = {
|
||||
"scales": None,
|
||||
"quant_mode": quant_mode,
|
||||
"group_ep": self.moe_all_to_all_group_name,
|
||||
"ep_world_size": self.ep_world_size,
|
||||
"ep_rank_id": self.ep_rank_id,
|
||||
}
|
||||
if self.need_extra_args:
|
||||
stage1_kwargs.update({
|
||||
"group_tp": self.moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if self.a3_need_extra_args and self.enable_dispatch_v2:
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
})
|
||||
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
def token_permutation(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_gate_up: Optional[Any] = None,
|
||||
shared_dequant_scale: Optional[Any] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
):
|
||||
self.expert_map = expert_map
|
||||
self.topk_ids = topk_ids
|
||||
self.topk_weights = topk_weights
|
||||
self.shared_experts = shared_experts
|
||||
|
||||
kwargs_mc2 = self.get_permute_mc2_kwargs(hidden_states, topk_weights,
|
||||
topk_ids, expert_map,
|
||||
global_redundant_expert_num)
|
||||
self.output = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||
**kwargs_mc2
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, self.dynamic_scale, self.assist_info_for_combine, \
|
||||
expert_token_nums, self.ep_recv_counts = self.output[0:5]
|
||||
|
||||
if self.with_quant:
|
||||
if shared_experts is not None:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(shared_gate_up, expand_x)
|
||||
shared_act_out = shared_experts.act_fn(
|
||||
(shared_gate_up, shared_dequant_scale))
|
||||
self.shared_act, self.swiglu_out_scale = \
|
||||
shared_act_out[0], shared_act_out[1]
|
||||
|
||||
else:
|
||||
if shared_experts is not None:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(hidden_states, topk_weights)
|
||||
shared_gate_up, _ = shared_experts.gate_up_proj(
|
||||
hidden_states)
|
||||
npu_wait_tensor(shared_gate_up, expand_x)
|
||||
self.shared_act = shared_experts.act_fn(shared_gate_up)
|
||||
group_list_type = 1
|
||||
return group_list_type, expand_x, expert_token_nums
|
||||
|
||||
def get_unpermute_mc_kwargs(self, hidden_states: torch.Tensor):
|
||||
assert self.expert_map is not None
|
||||
assert self.topk_weights is not None
|
||||
assert self.topk_ids is not None
|
||||
assert self.output is not None
|
||||
moe_expert_num = len(self.expert_map)
|
||||
forward_context = get_forward_context()
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
# moeCombine
|
||||
kwargs_mc2 = {
|
||||
"expand_x": hidden_states,
|
||||
"expert_ids": self.topk_ids,
|
||||
"expert_scales": self.topk_weights.to(torch.float32),
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
}
|
||||
if self.with_quant:
|
||||
tp_recv_counts = torch.empty(1,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
else:
|
||||
tp_recv_counts = self.output[5]
|
||||
stage3_kwargs = {
|
||||
"ep_send_counts": self.ep_recv_counts,
|
||||
"group_ep": self.moe_all_to_all_group_name,
|
||||
"ep_world_size": self.ep_world_size,
|
||||
"ep_rank_id": self.ep_rank_id,
|
||||
}
|
||||
if self.enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"assist_info_for_combine":
|
||||
self.assist_info_for_combine,
|
||||
})
|
||||
else:
|
||||
stage3_kwargs.update({
|
||||
"expand_idx": self.assist_info_for_combine,
|
||||
})
|
||||
if self.need_extra_args:
|
||||
stage3_kwargs.update({
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
"group_tp": self.moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if self.a3_need_extra_args and self.enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
})
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
def token_unpermutation(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
|
||||
kwargs_mc2 = self.get_unpermute_mc_kwargs(hidden_states)
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
|
||||
**kwargs_mc2
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
|
||||
**kwargs_mc2)
|
||||
if self.shared_experts is None:
|
||||
return hidden_states
|
||||
else:
|
||||
if self.with_quant:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(self.shared_act, hidden_states)
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
(self.shared_act, self.swiglu_out_scale))
|
||||
else:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(self.shared_act, hidden_states)
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
self.shared_act)
|
||||
return hidden_states, shared_hidden_states
|
||||
|
||||
|
||||
class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.apply_router_weight_on_input = kwargs.get(
|
||||
"apply_router_weight_on_input")
|
||||
self.top_k = kwargs.get("top_k")
|
||||
self.max_num_tokens = kwargs.get("max_num_tokens")
|
||||
ep_size = kwargs.get("ep_size")
|
||||
if ep_size is not None:
|
||||
self.num_experts_local = self.num_experts // ep_size
|
||||
self.with_quant = kwargs.get("with_quant")
|
||||
self.sorted_weights = None
|
||||
self.expanded_row_idx = None
|
||||
self.sorted_token_indices = None
|
||||
self.original_shape = None
|
||||
self.mask = None
|
||||
self.expert_map = None
|
||||
self.topk_weights = None
|
||||
self.topk_ids = None
|
||||
|
||||
def token_permutation(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_gate_up: Optional[Any] = None,
|
||||
shared_dequant_scale: Optional[Any] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
):
|
||||
self.original_shape = hidden_states.shape
|
||||
# assert len(original_shape) == 2
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
self.expert_map = expert_map
|
||||
self.topk_weights = topk_weights
|
||||
self.topk_ids = topk_ids
|
||||
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
||||
# ], "Only float32, float16, and bfsloat16 are supported"
|
||||
|
||||
if self.apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * \
|
||||
topk_weights.to(hidden_states.dtype)
|
||||
|
||||
if expert_map is not None:
|
||||
# Generate token indices and flatten
|
||||
token_indices = (torch.arange(
|
||||
num_tokens, device=device,
|
||||
dtype=torch.int64).unsqueeze(1).expand(-1,
|
||||
self.top_k).reshape(-1))
|
||||
|
||||
# Flatten token-to-expert mappings and map to local experts
|
||||
weights_flat = topk_weights.view(-1)
|
||||
experts_flat = topk_ids.view(-1)
|
||||
local_experts_flat = expert_map[experts_flat]
|
||||
|
||||
# Filter valid token-expert pairs
|
||||
self.mask = local_experts_flat != -1
|
||||
filtered_weights = torch.where(
|
||||
self.mask, weights_flat,
|
||||
torch.zeros_like(weights_flat)).to(dtype)
|
||||
filtered_experts = torch.where(
|
||||
self.mask, local_experts_flat,
|
||||
torch.full_like(local_experts_flat,
|
||||
self.num_experts_local)).to(topk_ids.dtype)
|
||||
|
||||
# Sort by local expert IDs
|
||||
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
|
||||
self.sorted_token_indices = token_indices[sort_indices]
|
||||
self.sorted_weights = filtered_weights[sort_indices]
|
||||
|
||||
# Compute token counts with minlength of num_experts
|
||||
# This is equivalent to but faster than:
|
||||
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
||||
token_counts = torch.zeros(self.num_experts_local + 1,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
||||
token_counts.scatter_add_(0, filtered_experts.to(torch.int64),
|
||||
ones)
|
||||
token_counts = token_counts[:self.num_experts_local]
|
||||
|
||||
# Rearrange hidden_states
|
||||
sorted_hidden_states = hidden_states[self.sorted_token_indices]
|
||||
if self.with_quant:
|
||||
group_list_type = 1
|
||||
else:
|
||||
expert_tokens = torch.cumsum(token_counts,
|
||||
dim=0,
|
||||
dtype=torch.int64)
|
||||
group_list_type = 0
|
||||
else:
|
||||
row_idx_len = num_tokens * self.top_k
|
||||
row_idx = (torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=device).view(self.top_k,
|
||||
-1).permute(
|
||||
1, 0).contiguous())
|
||||
active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens
|
||||
sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=active_num)
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
expanded_expert_idx, self.num_experts_local)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 0
|
||||
return group_list_type, sorted_hidden_states, expert_tokens
|
||||
|
||||
def token_unpermutation(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
assert self.mask is not None
|
||||
assert self.sorted_token_indices is not None
|
||||
assert self.sorted_weights is not None
|
||||
assert self.original_shape is not None
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
if self.expert_map is not None:
|
||||
weighted_down_out = hidden_states * \
|
||||
self.sorted_weights.unsqueeze(1)
|
||||
|
||||
final_hidden_states = torch.zeros(*self.original_shape,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
||||
# This created multiple NaN and index_add_ will mix them up which harms accuracy
|
||||
# remove this mask and filter after it being fixed
|
||||
num_valid_tokens = self.mask.sum()
|
||||
valid_token_mask = torch.arange(
|
||||
0, self.sorted_token_indices.shape[0],
|
||||
device=device).unsqueeze(1) < num_valid_tokens
|
||||
valid_output = torch.where(
|
||||
valid_token_mask, weighted_down_out,
|
||||
torch.zeros_like(weighted_down_out)).to(dtype)
|
||||
final_hidden_states.index_add_(0, self.sorted_token_indices,
|
||||
valid_output)
|
||||
else:
|
||||
if self.with_quant:
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=self.topk_weights,
|
||||
expanded_src_to_dst_row=self.expanded_row_idx,
|
||||
export_for_source_row=self.topk_ids,
|
||||
)
|
||||
if len(self.original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(
|
||||
self.original_shape)
|
||||
else:
|
||||
scales = torch.ones_like(
|
||||
self.topk_weights
|
||||
) if self.apply_router_weight_on_input else self.topk_weights
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=scales,
|
||||
expanded_src_to_dst_row=self.expanded_row_idx,
|
||||
export_for_source_row=self.topk_ids,
|
||||
)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
# mypy: disable-error-code="override"
|
||||
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(MoETokenDispatcher, self).__init__(**kwargs)
|
||||
self.apply_router_weight_on_input = kwargs.get(
|
||||
"apply_router_weight_on_input")
|
||||
ep_size = kwargs.get("ep_size")
|
||||
self.local_ep = ep_size
|
||||
self.top_k = kwargs.get("top_k")
|
||||
assert self.local_ep is not None
|
||||
self.local_num_experts = self.num_experts // self.local_ep
|
||||
self.local_num_group = self.top_k // self.local_ep
|
||||
self.bsz = None
|
||||
|
||||
def token_permutation(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_gate_up: Optional[Any] = None,
|
||||
shared_dequant_scale: Optional[Any] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
):
|
||||
|
||||
if self.apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * \
|
||||
topk_weights.to(hidden_states.dtype)
|
||||
|
||||
self.bsz, _ = hidden_states.shape
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
|
||||
self.sorted_hidden_states = hidden_states.index_select(
|
||||
0, self.sorted_topk_ids // self.local_num_group)
|
||||
|
||||
experts_id = torch.arange(0,
|
||||
self.local_num_experts,
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device)
|
||||
num_tokens_per_expert = (
|
||||
flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
||||
torch.float32).sum(0)
|
||||
self.topk_scales = topk_weights.view(-1).index_select(
|
||||
0, self.sorted_topk_ids).unsqueeze(-1)
|
||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
||||
return hidden_states, group_list
|
||||
|
||||
def token_unpermutation(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
assert self.local_ep is not None
|
||||
unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
|
||||
torch.int32)
|
||||
unsorted_hidden_states = hidden_states.index_select(
|
||||
0, unsorted_topk_ids)
|
||||
final_hidden_states = unsorted_hidden_states.reshape(
|
||||
self.bsz, self.top_k // self.local_ep, -1).sum(1)
|
||||
return final_hidden_states
|
||||
|
||||
Reference in New Issue
Block a user