[Refactor] Use dataclasses for output types related to FusedMoE (#5481)
Currently, in the Fused MoE module, methods of classes such as
MoECommMethod and MoETokenDispatcher return their outputs as
dictionaries or tuples, which hampers code maintainability,
readability, and extensibility. This PR introduces dataclasses for
these key output types to address those issues.
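For illustration, here is a minimal sketch of the new result types and the access
pattern they replace. The field sets are copied from the diff below; the `X | None`
union syntax requires Python 3.10+, as used in the new code:

    from dataclasses import dataclass, field

    import torch

    @dataclass
    class TokenDispatchResult:
        hidden_states: torch.Tensor
        group_list: torch.Tensor
        group_list_type: int
        dynamic_scale: torch.Tensor | None = field(default=None)
        topk_scales: torch.Tensor | None = field(default=None)
        context_metadata: dict = field(default_factory=dict)

    @dataclass
    class TokenCombineResult:
        routed_out: torch.Tensor

    # Before: callers unpacked an untyped dict and had to guess defaults:
    #     group_list = dispatch_output["group_list"]
    #     group_list_type = dispatch_output.get("group_list_type", 1)
    # After: typed attribute access; defaults live on the dataclass itself:
    #     group_list = dispatch_output.group_list
    #     group_list_type = dispatch_output.group_list_type

FusedExpertsResult follows the same pattern, bundling the routed output with the
optional group_list_type/expert_tokens bookkeeping used by dynamic EPLB.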
- vLLM version: v0.13.0
- vLLM main: 5326c89803
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
@@ -136,10 +136,10 @@ def test_token_dispatcher_with_all_gather(
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input)

-    sorted_hidden_states = dispatch_output["hidden_states"]
-    group_list = dispatch_output["group_list"]
-    group_list_type = dispatch_output.get("group_list_type", 1)
-    context_metadata = dispatch_output["context_metadata"]
+    sorted_hidden_states = dispatch_output.hidden_states
+    group_list = dispatch_output.group_list
+    group_list_type = dispatch_output.group_list_type
+    context_metadata = dispatch_output.context_metadata

     expert_output = apply_mlp(hidden_states=sorted_hidden_states,
                               w1=w1_local,
@@ -155,7 +155,7 @@ def test_token_dispatcher_with_all_gather(
     torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
                              expert_map)

-    torch.testing.assert_close(combined_output,
+    torch.testing.assert_close(combined_output.routed_out,
                                torch_output,
                                atol=4e-2,
                                rtol=1)
@@ -216,11 +216,11 @@ def test_token_dispatcher_with_all_gather_quant(
         apply_router_weight_on_input=apply_router_weight_on_input,
         with_quant=True)

-    sorted_hidden_states = dispatch_output["hidden_states"]
-    group_list = dispatch_output["group_list"]
-    group_list_type = dispatch_output.get("group_list_type", 1)
-    dynamic_scale = dispatch_output["dynamic_scale"]
-    context_metadata = dispatch_output["context_metadata"]
+    sorted_hidden_states = dispatch_output.hidden_states
+    group_list = dispatch_output.group_list
+    group_list_type = dispatch_output.group_list_type
+    dynamic_scale = dispatch_output.dynamic_scale
+    context_metadata = dispatch_output.context_metadata

     expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
                                       w1=w1,
@@ -235,7 +235,7 @@ def test_token_dispatcher_with_all_gather_quant(
         hidden_states=expert_output,
         context_metadata=context_metadata,
         bias=None)
-    assert combined_output.shape == (m, k)
+    assert combined_output.routed_out.shape == (m, k)
     gc.collect()
     torch.npu.empty_cache()
     torch.npu.reset_peak_memory_stats()
@@ -8,6 +8,8 @@ from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
                                                        AlltoAllCommImpl,
                                                        MC2CommImpl)
 from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
+from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
+                                                        TokenDispatchResult)


 class TestMoECommMethod(TestBase):
@@ -178,12 +180,12 @@ class TestMoECommMethod(TestBase):

         # Mock token dispatcher
         mock_td_instance = MagicMock()
-        mock_td_instance.token_dispatch.return_value = {
-            "hidden_states": torch.randn(6, 8),
-            "group_list": torch.tensor([2, 2, 2]),
-            "group_list_type": 1
-        }
-        mock_td_instance.token_combine.return_value = torch.randn(4, 8)
+        mock_td_instance.token_dispatch.return_value = TokenDispatchResult(
+            hidden_states=torch.randn(6, 8),
+            group_list=torch.tensor([2, 2, 2]),
+            group_list_type=1)
+        mock_td_instance.token_combine.return_value = TokenCombineResult(
+            routed_out=torch.randn(4, 8))
         mock_token_dispatcher.return_value = mock_td_instance

         # Mock unified_apply_mlp
@@ -213,7 +215,7 @@ class TestMoECommMethod(TestBase):
             activation="silu")

         # Verify result shape
-        self.assertEqual(result.shape, (4, 8))
+        self.assertEqual(result.routed_out.shape, (4, 8))

         # Verify token_dispatch was called
         mock_td_instance.token_dispatch.assert_called_once()
@@ -97,8 +97,7 @@ class TestTokenDispatcherWithMC2(TestBase):
                                                     topk_weights, topk_ids,
                                                     expert_map)
         mock_dispatch.assert_called_once()
-        self.assertEqual(output["group_list_type"],
-                         0)  # group_list_type == 0
+        self.assertEqual(output.group_list_type, 0)  # group_list_type == 0

     def test_token_dispatch_with_shared_experts_and_quant(self):
         self.shared_experts = MagicMock()
@@ -149,43 +148,6 @@ class TestTokenDispatcherWithMC2(TestBase):
                                                    context_metadata)
             self.assertIn("tp_send_counts", kwargs)

-    def test_token_combine_with_shared_experts(self):
-        shared_experts = MagicMock()
-        shared_experts.down_proj.return_value = (torch.randn(10, 128),
-                                                 torch.tensor(1.0))
-
-        topk_ids = torch.randint(0, 8, (10, 1))
-        topk_weights = torch.randn(10, 1)
-        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-        ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-        assist_info_for_combine = torch.arange(10)
-        tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-
-        context_metadata = {
-            "topk_ids": topk_ids,
-            "topk_weights": topk_weights,
-            "expert_map": expert_map,
-            "ep_recv_counts": ep_recv_counts,
-            "mc2_mask": None,
-            "assist_info_for_combine": assist_info_for_combine,
-            "expand_scales": None,
-            "shared_experts": shared_experts,
-            "shared_act": torch.randn(10, 128),
-            "swiglu_out_scale": torch.randn(10, 1),
-            "tp_recv_counts": tp_recv_counts
-        }
-
-        self.dispatcher.with_quant = True
-        self.dispatcher.need_extra_args = True
-        self.dispatcher.enable_dispatch_v2 = True
-
-        hidden_states = torch.randn(10, 128)
-        with patch("torch_npu.npu_moe_distribute_combine_v2",
-                   return_value=torch.randn(10, 128)):
-            result = self.dispatcher.token_combine(hidden_states,
-                                                   context_metadata)
-        self.assertIsInstance(result, tuple)
-

 class TestTokenDispatcherWithAllGather(TestBase):

@@ -233,7 +195,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
         self.mock_npu_moe_init_routing_v2.assert_called_once()
         args, kwargs = self.mock_npu_moe_init_routing_v2.call_args

-        self.assertEqual(results["group_list_type"], 1)
+        self.assertEqual(results.group_list_type, 1)

     def test_token_dispatch_with_expert_map(self):
         self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
@@ -248,7 +210,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
         self.mock_npu_moe_init_routing_v2.assert_called_once()
         args, kwargs = self.mock_npu_moe_init_routing_v2.call_args

-        self.assertEqual(results["group_list_type"], 1)
+        self.assertEqual(results.group_list_type, 1)

     def test_token_dispatch_without_quant(self):
         kwargs = {
@@ -268,7 +230,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
                                                  topk_weights, topk_ids,
                                                  None)

-        self.assertEqual(results["group_list_type"], 1)
+        self.assertEqual(results.group_list_type, 1)

     def test_token_dispatch_with_quant(self):
         kwargs = {
@@ -290,10 +252,10 @@ class TestTokenDispatcherWithAllGather(TestBase):
                                                  None,
                                                  with_quant=True)

-        self.assertIsNotNone(results["hidden_states"])
-        self.assertIsNotNone(results["group_list"])
-        self.assertIsNotNone(results["dynamic_scale"])
-        self.assertEqual(results["group_list_type"], 1)
+        self.assertIsNotNone(results.hidden_states)
+        self.assertIsNotNone(results.group_list)
+        self.assertIsNotNone(results.dynamic_scale)
+        self.assertEqual(results.group_list_type, 1)

     def test_token_combine_with_expert_map(self):
         hidden_states = torch.randn(6, 128)
@@ -303,7 +265,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
         }
         self.dispatcher.original_shape = (6, 128)
         final_hidden_states = self.dispatcher.token_combine(
-            hidden_states, context_metadata)
+            hidden_states, context_metadata).routed_out
         self.assertEqual(final_hidden_states.shape, (6, 128))

     def test_token_combine_without_expert_map(self):
@@ -314,7 +276,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
         }
         self.dispatcher.original_shape = (6, 128)
         final_hidden_states = self.dispatcher.token_combine(
-            hidden_states, context_metadata)
+            hidden_states, context_metadata).routed_out
         self.mock_npu_moe_token_unpermute.assert_called_once()
         self.assertEqual(final_hidden_states.shape, (6, 128))

@@ -326,7 +288,7 @@ class TestTokenDispatcherWithAllGather(TestBase):

         results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
                                                  topk_ids, None)
-        self.assertEqual(results["hidden_states"].shape, (6, 128))
+        self.assertEqual(results.hidden_states.shape, (6, 128))


 class TestTokenDispatcherWithAll2AllV(TestBase):
@@ -437,9 +399,9 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
                                                 topk_ids=topk_ids,
                                                 expert_map=expert_map)

-        self.assertIsNotNone(result["hidden_states"])
-        self.assertIsNotNone(result["group_list"])
-        self.assertEqual(result["group_list_type"], 1)
+        self.assertIsNotNone(result.hidden_states)
+        self.assertIsNotNone(result.group_list)
+        self.assertEqual(result.group_list_type, 1)

     def test_token_combine(self):
         hidden_states = torch.randn(16, 16)
@@ -458,7 +420,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):

         output = self.dispatcher.token_combine(hidden_states, context_metadata)
         self.assertIsNotNone(output)
-        self.assertEqual(output.shape, (8, 16))
+        self.assertEqual(output.routed_out.shape, (8, 16))

     def test_token_dispatch_with_quant(self):
         self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
@@ -480,10 +442,10 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
                                                 expert_map=expert_map,
                                                 with_quant=True)

-        self.assertIsNotNone(result["hidden_states"])
-        self.assertIsNotNone(result["group_list"])
-        self.assertIsNotNone(result["dynamic_scale"])
-        self.assertEqual(result["group_list_type"], 1)
+        self.assertIsNotNone(result.hidden_states)
+        self.assertIsNotNone(result.group_list)
+        self.assertIsNotNone(result.dynamic_scale)
+        self.assertEqual(result.group_list_type, 1)

     def test_token_dispatch_with_quant_no_active_tokens(self):
         self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
@@ -508,10 +470,10 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
                                                 expert_map=expert_map,
                                                 with_quant=True)

-        self.assertIsNotNone(result["hidden_states"])
-        self.assertIsNotNone(result["group_list"])
-        self.assertIsNotNone(result["dynamic_scale"])
-        self.assertEqual(result["group_list_type"], 1)
+        self.assertIsNotNone(result.hidden_states)
+        self.assertIsNotNone(result.group_list)
+        self.assertIsNotNone(result.dynamic_scale)
+        self.assertEqual(result.group_list_type, 1)

     def test_token_dispatch_with_log2phy(self):
         hidden_states = torch.randn(8, 16)
@@ -530,6 +492,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
                                                 expert_map=expert_map,
                                                 log2phy=log2phy)

-        self.assertIsNotNone(result["hidden_states"])
-        self.assertIsNotNone(result["group_list"])
-        self.assertEqual(result["group_list_type"], 1)
+        self.assertIsNotNone(result.hidden_states)
+        self.assertIsNotNone(result.group_list)
+        self.assertEqual(result.group_list_type, 1)
@@ -37,6 +37,7 @@ from vllm_ascend.flash_common3_context import (get_flash_common3_context,
                                                set_flash_common3_context)
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
+                                                       FusedExpertsResult,
                                                        setup_moe_comm_method)
 from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
 from vllm_ascend.quantization.w4a8_dynamic import \
@@ -325,7 +326,7 @@ class AscendFusedMoE(FusedMoE):
             pertoken_scale = None

         # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
+        fused_experts_results: FusedExpertsResult = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
@@ -350,25 +351,25 @@ class AscendFusedMoE(FusedMoE):
             global_redundant_expert_num=self.global_redundant_expert_num,
             mc2_mask=mc2_mask)

-        if isinstance(final_hidden_states, tuple):
-            final_hidden_states, group_list_type, expert_tokens = final_hidden_states
-            if self.dynamic_eplb:
-                moe_load_stream = moe_load_async_stream()
-                cur_stream = torch.npu.current_stream()
-
-                moe_load_stream.wait_stream(cur_stream)
-                with npu_stream_switch(moe_load_stream):
-                    self.moe_load += expert_tokens if group_list_type == 1 else \
-                        torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
-                cur_stream.wait_stream(moe_load_stream)
+        if self.dynamic_eplb:
+            expert_tokens = fused_experts_results.expert_tokens
+            group_list_type = fused_experts_results.group_list_type
+            assert expert_tokens is not None and group_list_type is not None, \
+                "expert_tokens and group_list_type should not be None when dynamic_eplb is enabled."
+            moe_load_stream = moe_load_async_stream()
+            cur_stream = torch.npu.current_stream()
+            moe_load_stream.wait_stream(cur_stream)
+            with npu_stream_switch(moe_load_stream):
+                self.moe_load += expert_tokens if group_list_type == 1 else \
+                    torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
+            cur_stream.wait_stream(moe_load_stream)

-        final_hidden_states = forward_context.moe_comm_method.finalize(
-            hidden_states=final_hidden_states,
+        routed_out = forward_context.moe_comm_method.finalize(
+            hidden_states=fused_experts_results.routed_out,
             reduce_results=self.reduce_results,
             context_metadata=context_metadata)

-        return final_hidden_states
+        return routed_out


 class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
@@ -439,7 +440,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         else:
             set_flash_common3_context(shared_experts=self._shared_experts)

-        fused_output = AscendFusedMoE.forward_impl(
+        routed_out = AscendFusedMoE.forward_impl(
             self,
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -462,4 +463,4 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         assert fc3_context is not None
         shared_out = fc3_context.shared_out

-        return shared_out, fused_output
+        return shared_out, routed_out
@@ -16,6 +16,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Any, Dict, Optional

 import torch
@@ -26,11 +27,11 @@ import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.prepare_finalize import (
-    PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
-    PrepareAndFinalizeWithMC2, QuantType)
+    PrepareAndFinalize, PrepareAndFinalizeWithAll2All,
+    PrepareAndFinalizeWithAllGather, PrepareAndFinalizeWithMC2, QuantType)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
-    TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
-    TokenDispatcherWithMC2)
+    MoETokenDispatcher, TokenDispatcherWithAll2AllV,
+    TokenDispatcherWithAllGather, TokenDispatcherWithMC2)

 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}

@@ -47,6 +48,14 @@ def setup_moe_comm_method(moe_config):
     _MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config)


+@dataclass
+class FusedExpertsResult:
+    routed_out: torch.Tensor
+    # For dynamic_eplb
+    group_list_type: int | None = None
+    expert_tokens: torch.Tensor | None = None
+
+
 class MoECommMethod(ABC):
     """Base class for MoE communication methods."""

@@ -118,7 +127,7 @@ class MoECommMethod(ABC):
         moe_comm_method = get_forward_context().moe_comm_method
         assert moe_comm_method is not None, "Missing communication context"

-        results = self.token_dispatcher.token_dispatch(
+        dispatch_results = self.token_dispatcher.token_dispatch(
             hidden_states=hidden_states,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
@@ -134,43 +143,41 @@ class MoECommMethod(ABC):
             dynamic_eplb=dynamic_eplb,
             pertoken_scale=pertoken_scale)

-        permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
-            results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
-
-        mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
-                                       w1=w1,
-                                       w1_scale=w1_scale,
-                                       w2=w2,
-                                       w2_scale=w2_scale,
-                                       group_list=expert_tokens,
-                                       dynamic_scale=dynamic_scale,
-                                       group_list_type=group_list_type,
-                                       w1_scale_bias=w1_scale_bias,
-                                       w2_scale_bias=w2_scale_bias,
-                                       w1_offset=w1_offset,
-                                       w2_offset=w2_offset,
-                                       topk_scales=topk_scales,
-                                       with_quant=use_int8_w8a8
-                                       or use_int4_w4a8 or use_int4_w4a16,
-                                       fusion=use_int8_w8a8,
-                                       need_trans=need_trans,
-                                       dynamic_eplb=dynamic_eplb)
+        mlp_output = unified_apply_mlp(
+            hidden_states=dispatch_results.hidden_states,
+            w1=w1,
+            w1_scale=w1_scale,
+            w2=w2,
+            w2_scale=w2_scale,
+            group_list=dispatch_results.group_list,
+            dynamic_scale=dispatch_results.dynamic_scale,
+            group_list_type=dispatch_results.group_list_type,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+            w1_offset=w1_offset,
+            w2_offset=w2_offset,
+            topk_scales=dispatch_results.topk_scales,
+            with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16,
+            fusion=use_int8_w8a8,
+            need_trans=need_trans,
+            dynamic_eplb=dynamic_eplb)

-        final_hidden_states = self.token_dispatcher.token_combine(
-            hidden_states=mlp_output, context_metadata=context_metadata)
-
-        if dynamic_eplb:
-            return (final_hidden_states, group_list_type, expert_tokens)
-
-        return final_hidden_states
+        combine_results = self.token_dispatcher.token_combine(
+            hidden_states=mlp_output,
+            context_metadata=dispatch_results.context_metadata)
+
+        return FusedExpertsResult(
+            routed_out=combine_results.routed_out,
+            group_list_type=dispatch_results.group_list_type,
+            expert_tokens=dispatch_results.group_list)

     @abstractmethod
-    def _get_token_dispatcher(self):
+    def _get_token_dispatcher(self) -> MoETokenDispatcher:
         raise NotImplementedError(
             "_get_token_dispatcher function not implemented.")

     @abstractmethod
-    def _get_prepare_finalize(self):
+    def _get_prepare_finalize(self) -> PrepareAndFinalize:
         raise NotImplementedError(
             "_get_prepare_finalize function not implemented.")
@@ -292,9 +299,11 @@ class FusedMC2CommImpl(MoECommMethod):
             w1_scale is None or w2_scale is None
         ), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."

+        assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \
+            "token_dispatcher must be an instance of TokenDispatcherWithMC2."
         if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
             out = torch.empty_like(hidden_states)
-            torch.ops._C_ascend.dispatch_ffn_combine(
+            torch.ops._C_ascend.dispatch_ffn_combine(  # type: ignore
                 x=hidden_states,
                 weight1=w1[0],
                 weight2=w2[0],
@@ -308,7 +317,7 @@ class FusedMC2CommImpl(MoECommMethod):
             )
         elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
             assert expert_map is not None, "expert_map cannot be None."
-            out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(
+            out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(  # type: ignore
                 x=hidden_states,
                 expert_ids=topk_ids,
                 gmm1_permuted_weight=w1[0],
@@ -325,4 +334,4 @@ class FusedMC2CommImpl(MoECommMethod):
         else:
             raise ValueError(
                 f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
-        return out
+        return FusedExpertsResult(routed_out=out)
@@ -21,6 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
 from typing import Any, Optional

 import torch
@@ -35,6 +36,21 @@ from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                                is_hierarchical_communication_enabled)


+@dataclass
+class TokenDispatchResult:
+    hidden_states: torch.Tensor
+    group_list: torch.Tensor
+    group_list_type: int
+    dynamic_scale: torch.Tensor | None = field(default=None)
+    topk_scales: torch.Tensor | None = field(default=None)
+    context_metadata: dict = field(default_factory=dict)
+
+
+@dataclass
+class TokenCombineResult:
+    routed_out: torch.Tensor
+
+
 class MoETokenDispatcher(ABC):

     def __init__(self, **kwargs) -> None:
@@ -74,14 +90,14 @@ class MoETokenDispatcher(ABC):
                        with_quant: bool = False,
                        dynamic_eplb: bool = False,
                        pertoken_scale: Optional[torch.Tensor] = None,
-                       ):
+                       ) -> TokenDispatchResult:
         raise NotImplementedError("Dispatch function not implemented.")

     @abstractmethod
     def token_combine(self,
                       hidden_states: torch.Tensor,
                       context_metadata: dict,
-                      bias: torch.Tensor = None):
+                      bias: torch.Tensor | None = None) -> TokenCombineResult:
         raise NotImplementedError("Combine function not implemented.")


@@ -207,24 +223,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
             ep_recv_counts, tp_recv_counts, expand_scales = output[0:7]

-        # Handle shared experts (store intermediate results in local vars, not self)
-        shared_act = None
-        swiglu_out_scale = None
-        if with_quant:
-            if shared_experts is not None:
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
-                shared_act_out = shared_experts.act_fn(
-                    (shared_gate_up, shared_dequant_scale))
-                shared_act, swiglu_out_scale = shared_act_out[
-                    0], shared_act_out[1]
-        else:
-            if shared_experts is not None:
-                shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-                shared_act = shared_experts.act_fn(shared_gate_up)
-
         context_metadata = {
             "topk_ids": topk_ids,
             "topk_weights": topk_weights,
@@ -233,20 +231,16 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
             "tp_recv_counts": tp_recv_counts,
             "assist_info_for_combine": assist_info_for_combine,
-            "shared_experts": shared_experts,
-            "shared_act": shared_act,
-            "swiglu_out_scale": swiglu_out_scale,
             "expand_scales": expand_scales
         }

         group_list_type = 0

-        return {
-            "group_list_type": group_list_type,
-            "hidden_states": expand_x,
-            "group_list": expert_token_nums,
-            "dynamic_scale": dynamic_scale,
-            "context_metadata": context_metadata
-        }
+        return TokenDispatchResult(hidden_states=expand_x,
+                                   dynamic_scale=dynamic_scale,
+                                   group_list=expert_token_nums,
+                                   group_list_type=group_list_type,
+                                   context_metadata=context_metadata)

     def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
                               context_metadata: dict):
@@ -300,12 +294,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
             kwargs_mc2.update(stage3_kwargs)
         return kwargs_mc2

-    def token_combine(
-        self,
-        hidden_states: torch.Tensor,
-        context_metadata: dict,
-        bias: torch.Tensor = None,
-    ):
+    def token_combine(self, hidden_states, context_metadata, bias=None):
         assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."

         kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states,
@@ -313,20 +302,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
             if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

-        # Handle shared experts from metadata
-        shared_experts = context_metadata["shared_experts"]
-        if shared_experts is None:
-            return combined_output
-
-        shared_act = context_metadata["shared_act"]
-        if self.with_quant:
-            swiglu_out_scale = context_metadata["swiglu_out_scale"]
-            shared_hidden_states, _ = shared_experts.down_proj(
-                (shared_act, swiglu_out_scale))
-        else:
-            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
-
-        return combined_output, shared_hidden_states
+        return TokenCombineResult(routed_out=combined_output)


 class TokenDispatcherWithAllGather(MoETokenDispatcher):
@@ -401,18 +377,16 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
             "topk_weights": topk_weights,
             "expanded_row_idx": expanded_row_idx
         }
-        return {
-            "group_list_type": group_list_type,
-            "hidden_states": sorted_hidden_states,
-            "group_list": expert_tokens,
-            "dynamic_scale": pertoken_scale if self.with_quant else None,
-            "context_metadata": context_metadata
-        }
+        return TokenDispatchResult(
+            hidden_states=sorted_hidden_states,
+            dynamic_scale=pertoken_scale if self.with_quant else None,
+            group_list=expert_tokens,
+            group_list_type=group_list_type,
+            context_metadata=context_metadata,
+        )

-    def token_combine(self,
-                      hidden_states: torch.Tensor,
-                      context_metadata: dict,
-                      bias: torch.Tensor = None):
+    def token_combine(self, hidden_states, context_metadata, bias=None):
         assert self.original_shape is not None
         final_hidden_states = torch_npu.npu_moe_token_unpermute(
             permuted_tokens=hidden_states,
@@ -422,7 +396,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
         final_hidden_states = final_hidden_states.view(self.original_shape)

         # these values are no longer used, so they need to be set to None for memory release.
-        return final_hidden_states
+        return TokenCombineResult(routed_out=final_hidden_states)


 class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
@@ -530,20 +504,15 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
                 reversed_global_input_permutation_mapping
         }

-        return {
-            "hidden_states": global_input_tokens,
-            "group_list": tokens_per_expert,
-            "group_list_type": 1,
-            "dynamic_scale": dynamic_scale_final,
-            "context_metadata": context_metadata,
-        }
+        return TokenDispatchResult(
+            hidden_states=global_input_tokens,
+            dynamic_scale=dynamic_scale_final,
+            group_list=tokens_per_expert,
+            group_list_type=1,
+            context_metadata=context_metadata,
+        )

-    def token_combine(
-        self,
-        hidden_states: torch.Tensor,
-        context_metadata: dict,
-        bias: torch.Tensor = None,
-    ):
+    def token_combine(self, hidden_states, context_metadata, bias=None):
         assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."

         # 1. Preprocess using metadata
@@ -564,7 +533,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
         output = self._combine_postprocess(permutated_local_input_tokens,
                                            context_metadata)

-        return output
+        return TokenCombineResult(routed_out=output)

     def _dispatch_preprocess(self, hidden_states, topk_ids):
         assert self.hidden_shape is not None