[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it? Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries. - Prepare, dispatch, MLP, and quant stages now have clearer ownership. - Main MoE path no longer depends on business `kwargs.get(...)` lookups. - Comm and dispatcher interfaces are request-only on the main path. - UTs can assert stage-level fields directly instead of inferring behavior indirectly. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed. --------- Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -245,6 +245,7 @@ def test_qwen3_dense_prefetch_mlp_weight_tp2(model):
|
|||||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
||||||
@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
|
@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
|
||||||
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
|
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
|
||||||
|
@wait_until_npu_memory_free()
|
||||||
def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
|
def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
|
||||||
short_example_prompts = [
|
short_example_prompts = [
|
||||||
"Hello ",
|
"Hello ",
|
||||||
@@ -272,6 +273,7 @@ def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
|
|||||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
||||||
@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
|
@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
|
||||||
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
|
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
|
||||||
|
@wait_until_npu_memory_free()
|
||||||
def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep():
|
def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep():
|
||||||
short_example_prompts = [
|
short_example_prompts = [
|
||||||
"Hello ",
|
"Hello ",
|
||||||
|
|||||||
@@ -28,11 +28,17 @@ import torch
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import (
|
from vllm_ascend.ops.fused_moe.experts_selector import check_npu_moe_gating_top_k, select_experts
|
||||||
check_npu_moe_gating_top_k, select_experts)
|
|
||||||
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
||||||
from vllm_ascend.ops.fused_moe.token_dispatcher import \
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||||
TokenDispatcherWithAllGather
|
build_fused_experts_input,
|
||||||
|
build_mlp_compute_input,
|
||||||
|
MoEQuantParams,
|
||||||
|
MoERoutingParams,
|
||||||
|
MoETokenDispatchInput,
|
||||||
|
)
|
||||||
|
from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather
|
||||||
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
|
|
||||||
NUM_EXPERTS = [8, 64]
|
NUM_EXPERTS = [8, 64]
|
||||||
EP_SIZE = [1]
|
EP_SIZE = [1]
|
||||||
@@ -83,10 +89,8 @@ def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
|
|||||||
for i in range(w1.shape[0]):
|
for i in range(w1.shape[0]):
|
||||||
mask = topk_ids == i
|
mask = topk_ids == i
|
||||||
if mask.sum():
|
if mask.sum():
|
||||||
out[mask] = SiluAndMul()(
|
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
||||||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
return (out.view(B, -1, w2.shape[1]) * topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||||
return (out.view(B, -1, w2.shape[1]) *
|
|
||||||
topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
|
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
|
||||||
@@ -129,36 +133,41 @@ def test_token_dispatcher_with_all_gather(
|
|||||||
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
|
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
|
||||||
|
|
||||||
apply_router_weight_on_input = False
|
apply_router_weight_on_input = False
|
||||||
dispatch_output = dispatcher.token_dispatch(
|
token_dispatch_output = dispatcher.token_dispatch(
|
||||||
hidden_states=a,
|
token_dispatch_input=MoETokenDispatchInput(
|
||||||
topk_weights=topk_weights,
|
hidden_states=a,
|
||||||
topk_ids=topk_ids,
|
topk_weights=topk_weights,
|
||||||
expert_map=expert_map,
|
topk_ids=topk_ids,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
routing=MoERoutingParams(
|
||||||
|
expert_map=expert_map,
|
||||||
|
global_redundant_expert_num=0,
|
||||||
|
mc2_mask=None,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
),
|
||||||
|
quant=MoEQuantParams(quant_type=QuantType.NONE),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
sorted_hidden_states = dispatch_output.hidden_states
|
sorted_hidden_states = token_dispatch_output.hidden_states
|
||||||
group_list = dispatch_output.group_list
|
group_list = token_dispatch_output.group_list
|
||||||
group_list_type = dispatch_output.group_list_type
|
group_list_type = token_dispatch_output.group_list_type
|
||||||
context_metadata = dispatch_output.context_metadata
|
combine_metadata = token_dispatch_output.combine_metadata
|
||||||
|
|
||||||
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
|
expert_output = apply_mlp(
|
||||||
w1=w1_local,
|
hidden_states=sorted_hidden_states,
|
||||||
w2=w2_local,
|
w1=w1_local,
|
||||||
group_list=group_list,
|
w2=w2_local,
|
||||||
group_list_type=group_list_type)
|
group_list=group_list,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
)
|
||||||
|
|
||||||
combined_output = dispatcher.token_combine(
|
combined_output = dispatcher.token_combine(
|
||||||
hidden_states=expert_output,
|
hidden_states=expert_output, combine_metadata=combine_metadata, bias=None
|
||||||
context_metadata=context_metadata,
|
)
|
||||||
bias=None)
|
|
||||||
|
|
||||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
|
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map)
|
||||||
expert_map)
|
|
||||||
|
|
||||||
torch.testing.assert_close(combined_output.routed_out,
|
torch.testing.assert_close(combined_output, torch_output, atol=4e-2, rtol=1)
|
||||||
torch_output,
|
|
||||||
atol=4e-2,
|
|
||||||
rtol=1)
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.npu.empty_cache()
|
torch.npu.empty_cache()
|
||||||
torch.npu.reset_peak_memory_stats()
|
torch.npu.reset_peak_memory_stats()
|
||||||
@@ -184,8 +193,7 @@ def test_token_dispatcher_with_all_gather_quant(
|
|||||||
):
|
):
|
||||||
context_mock = MagicMock()
|
context_mock = MagicMock()
|
||||||
context_mock.fused_moe_state = 0
|
context_mock.fused_moe_state = 0
|
||||||
with patch("vllm_ascend.ascend_forward_context.get_forward_context",
|
with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context", return_value=context_mock):
|
||||||
return_value=context_mock):
|
|
||||||
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
||||||
w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
|
w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
|
||||||
w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype)
|
w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype)
|
||||||
@@ -208,34 +216,44 @@ def test_token_dispatcher_with_all_gather_quant(
|
|||||||
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
|
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
|
||||||
|
|
||||||
apply_router_weight_on_input = False
|
apply_router_weight_on_input = False
|
||||||
dispatch_output = dispatcher.token_dispatch(
|
token_dispatch_output = dispatcher.token_dispatch(
|
||||||
hidden_states=a,
|
token_dispatch_input=MoETokenDispatchInput(
|
||||||
topk_weights=topk_weights,
|
hidden_states=a,
|
||||||
topk_ids=topk_ids,
|
topk_weights=topk_weights,
|
||||||
expert_map=expert_map,
|
topk_ids=topk_ids,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
routing=MoERoutingParams(
|
||||||
with_quant=True)
|
expert_map=expert_map,
|
||||||
|
global_redundant_expert_num=0,
|
||||||
|
mc2_mask=None,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
),
|
||||||
|
quant=MoEQuantParams(quant_type=QuantType.W8A8),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
sorted_hidden_states = dispatch_output.hidden_states
|
combine_metadata = token_dispatch_output.combine_metadata
|
||||||
group_list = dispatch_output.group_list
|
|
||||||
group_list_type = dispatch_output.group_list_type
|
|
||||||
dynamic_scale = dispatch_output.dynamic_scale
|
|
||||||
context_metadata = dispatch_output.context_metadata
|
|
||||||
|
|
||||||
expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
|
mlp_compute_input = build_mlp_compute_input(
|
||||||
w1=w1,
|
fused_experts_input=build_fused_experts_input(
|
||||||
w1_scale=w1_scale,
|
hidden_states=a,
|
||||||
w2=w2,
|
topk_weights=topk_weights,
|
||||||
w2_scale=w2_scale,
|
topk_ids=topk_ids,
|
||||||
group_list=group_list,
|
w1=w1,
|
||||||
group_list_type=group_list_type,
|
w2=w2,
|
||||||
dynamic_scale=dynamic_scale,
|
quant_type=QuantType.W8A8,
|
||||||
with_quant=True)
|
dynamic_eplb=False,
|
||||||
|
expert_map=expert_map,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
),
|
||||||
|
token_dispatch_output=token_dispatch_output,
|
||||||
|
use_fusion_ops=False,
|
||||||
|
)
|
||||||
|
expert_output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
|
||||||
combined_output = dispatcher.token_combine(
|
combined_output = dispatcher.token_combine(
|
||||||
hidden_states=expert_output,
|
hidden_states=expert_output, combine_metadata=combine_metadata, bias=None
|
||||||
context_metadata=context_metadata,
|
)
|
||||||
bias=None)
|
assert combined_output.shape == (m, k)
|
||||||
assert combined_output.routed_out.shape == (m, k)
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.npu.empty_cache()
|
torch.npu.empty_cache()
|
||||||
torch.npu.reset_peak_memory_stats()
|
torch.npu.reset_peak_memory_stats()
|
||||||
@@ -271,25 +289,20 @@ def test_select_experts(
|
|||||||
hidden_states = torch.randn(m, n, device=device, dtype=dtype)
|
hidden_states = torch.randn(m, n, device=device, dtype=dtype)
|
||||||
router_logits = torch.randn(m, e, device=device, dtype=dtype)
|
router_logits = torch.randn(m, e, device=device, dtype=dtype)
|
||||||
|
|
||||||
e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype)
|
e_score_correction_bias = torch.randn(e, device=device, dtype=dtype) if with_e_correction else None
|
||||||
if with_e_correction else None)
|
|
||||||
|
|
||||||
custom_routing_function = None
|
custom_routing_function = None
|
||||||
if custom_routing:
|
if custom_routing:
|
||||||
custom_routing_function = MagicMock()
|
custom_routing_function = MagicMock()
|
||||||
mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
|
mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
|
||||||
mock_ids = torch.randint(0,
|
mock_ids = torch.randint(0, e, (m, topk), device=device, dtype=torch.int32)
|
||||||
e, (m, topk),
|
|
||||||
device=device,
|
|
||||||
dtype=torch.int32)
|
|
||||||
custom_routing_function.return_value = (mock_weights, mock_ids)
|
custom_routing_function.return_value = (mock_weights, mock_ids)
|
||||||
|
|
||||||
with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
|
with (
|
||||||
) as mock_native_grouped_topk, \
|
patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk") as mock_native_grouped_topk,
|
||||||
patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
|
patch("vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method", return_value=MagicMock()),
|
||||||
return_value=MagicMock()):
|
):
|
||||||
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(x)
|
||||||
x)
|
|
||||||
|
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
@@ -305,8 +318,8 @@ def test_select_experts(
|
|||||||
)
|
)
|
||||||
|
|
||||||
call_moe_gatingtopk = check_npu_moe_gating_top_k(
|
call_moe_gatingtopk = check_npu_moe_gating_top_k(
|
||||||
hidden_states, topk, renormalize, topk_group, num_expert_group,
|
hidden_states, topk, renormalize, topk_group, num_expert_group, scoring_func, custom_routing_function
|
||||||
scoring_func, custom_routing_function)
|
)
|
||||||
if not call_moe_gatingtopk and use_grouped_topk:
|
if not call_moe_gatingtopk and use_grouped_topk:
|
||||||
mock_native_grouped_topk.assert_called_once()
|
mock_native_grouped_topk.assert_called_once()
|
||||||
else:
|
else:
|
||||||
@@ -323,16 +336,18 @@ def test_select_experts(
|
|||||||
|
|
||||||
@pytest.mark.parametrize("device", DEVICE)
|
@pytest.mark.parametrize("device", DEVICE)
|
||||||
def test_select_experts_invalid_scoring_func(device: str):
|
def test_select_experts_invalid_scoring_func(device: str):
|
||||||
with patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
|
with (
|
||||||
return_value=MagicMock()), \
|
patch("vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method", return_value=MagicMock()),
|
||||||
pytest.raises(ValueError,
|
pytest.raises(ValueError, match="Unsupported scoring function: invalid"),
|
||||||
match="Unsupported scoring function: invalid"):
|
):
|
||||||
select_experts(hidden_states=torch.randn(1, 128, device=device),
|
select_experts(
|
||||||
router_logits=torch.randn(1, 8, device=device),
|
hidden_states=torch.randn(1, 128, device=device),
|
||||||
top_k=2,
|
router_logits=torch.randn(1, 8, device=device),
|
||||||
use_grouped_topk=False,
|
top_k=2,
|
||||||
renormalize=False,
|
use_grouped_topk=False,
|
||||||
scoring_func="invalid")
|
renormalize=False,
|
||||||
|
scoring_func="invalid",
|
||||||
|
)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.npu.empty_cache()
|
torch.npu.empty_cache()
|
||||||
torch.npu.reset_peak_memory_stats()
|
torch.npu.reset_peak_memory_stats()
|
||||||
|
|||||||
@@ -19,6 +19,38 @@ import torch
|
|||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend._310p.fused_moe.moe_mlp import unified_apply_mlp
|
from vllm_ascend._310p.fused_moe.moe_mlp import unified_apply_mlp
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||||
|
MoEMlpComputeInput,
|
||||||
|
MoEQuantParams,
|
||||||
|
MoEWeights,
|
||||||
|
)
|
||||||
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
|
|
||||||
|
|
||||||
|
def build_mlp_compute_input_fixture(
|
||||||
|
*,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
group_list: torch.Tensor,
|
||||||
|
with_quant: bool,
|
||||||
|
w1_scale: torch.Tensor | None = None,
|
||||||
|
w2_scale: torch.Tensor | None = None,
|
||||||
|
group_list_type: int = 1,
|
||||||
|
) -> MoEMlpComputeInput:
|
||||||
|
return MoEMlpComputeInput(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
group_list=group_list,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
dynamic_scale=None,
|
||||||
|
topk_scales=None,
|
||||||
|
weights=MoEWeights(w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale),
|
||||||
|
quant=MoEQuantParams(quant_type=QuantType.W8A8 if with_quant else QuantType.NONE),
|
||||||
|
fusion=False,
|
||||||
|
activation="silu",
|
||||||
|
need_trans=False,
|
||||||
|
dynamic_eplb=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestUnifiedApplyMLP310(TestBase):
|
class TestUnifiedApplyMLP310(TestBase):
|
||||||
@@ -38,14 +70,13 @@ class TestUnifiedApplyMLP310(TestBase):
|
|||||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||||
|
|
||||||
result = unified_apply_mlp(
|
result = unified_apply_mlp(
|
||||||
hidden_states=hidden_states,
|
mlp_compute_input=build_mlp_compute_input_fixture(
|
||||||
w1=w1,
|
hidden_states=hidden_states,
|
||||||
w1_scale=None,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
w2_scale=None,
|
group_list=group_list,
|
||||||
group_list=group_list,
|
with_quant=False,
|
||||||
group_list_type=1,
|
)
|
||||||
with_quant=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||||
@@ -94,14 +125,15 @@ class TestUnifiedApplyMLP310(TestBase):
|
|||||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||||
|
|
||||||
result = unified_apply_mlp(
|
result = unified_apply_mlp(
|
||||||
hidden_states=hidden_states,
|
mlp_compute_input=build_mlp_compute_input_fixture(
|
||||||
w1=w1,
|
hidden_states=hidden_states,
|
||||||
w1_scale=w1_scale,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
w2_scale=w2_scale,
|
group_list=group_list,
|
||||||
group_list=group_list,
|
with_quant=True,
|
||||||
group_list_type=1,
|
w1_scale=w1_scale,
|
||||||
with_quant=True,
|
w2_scale=w2_scale,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_cumsum.assert_called_once()
|
mock_cumsum.assert_called_once()
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
#
|
#
|
||||||
from typing import List, TypedDict
|
from typing import TypedDict
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -20,12 +20,19 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.ascend_forward_context import MoECommType
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||||
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
|
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
|
||||||
from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
|
from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp
|
||||||
unified_apply_mlp)
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||||
|
MoEMlpComputeInput,
|
||||||
|
MoEPrepareOutput,
|
||||||
|
MoEQuantParams,
|
||||||
|
MoEWeights,
|
||||||
|
)
|
||||||
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
from vllm_ascend.utils import AscendDeviceType, adapt_patch
|
from vllm_ascend.utils import AscendDeviceType, adapt_patch
|
||||||
|
|
||||||
adapt_patch(True)
|
adapt_patch(True)
|
||||||
@@ -54,6 +61,51 @@ def mock_npu_format_cast(weight_data, format):
|
|||||||
return weight_data
|
return weight_data
|
||||||
|
|
||||||
|
|
||||||
|
def build_mlp_compute_input_fixture(
|
||||||
|
*,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor | list[torch.Tensor],
|
||||||
|
w2: torch.Tensor | list[torch.Tensor],
|
||||||
|
group_list: torch.Tensor,
|
||||||
|
with_quant: bool,
|
||||||
|
group_list_type: int = 1,
|
||||||
|
dynamic_scale: torch.Tensor | None = None,
|
||||||
|
topk_scales: torch.Tensor | None = None,
|
||||||
|
w1_scale: torch.Tensor | list[torch.Tensor] | None = None,
|
||||||
|
w2_scale: torch.Tensor | list[torch.Tensor] | None = None,
|
||||||
|
w1_scale_bias: torch.Tensor | None = None,
|
||||||
|
w2_scale_bias: torch.Tensor | None = None,
|
||||||
|
w1_offset: torch.Tensor | None = None,
|
||||||
|
w2_offset: torch.Tensor | None = None,
|
||||||
|
fusion: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
need_trans: bool = True,
|
||||||
|
dynamic_eplb: bool = False,
|
||||||
|
) -> MoEMlpComputeInput:
|
||||||
|
return MoEMlpComputeInput(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
group_list=group_list,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
dynamic_scale=dynamic_scale,
|
||||||
|
topk_scales=topk_scales,
|
||||||
|
weights=MoEWeights(
|
||||||
|
w1=w1,
|
||||||
|
w2=w2,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
w1_scale_bias=w1_scale_bias,
|
||||||
|
w2_scale_bias=w2_scale_bias,
|
||||||
|
w1_offset=w1_offset,
|
||||||
|
w2_offset=w2_offset,
|
||||||
|
),
|
||||||
|
quant=MoEQuantParams(quant_type=QuantType.W8A8 if with_quant else QuantType.NONE),
|
||||||
|
fusion=fusion,
|
||||||
|
activation=activation,
|
||||||
|
need_trans=need_trans,
|
||||||
|
dynamic_eplb=dynamic_eplb,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def setup_vllm_config_mock(mocker: MockerFixture):
|
def setup_vllm_config_mock(mocker: MockerFixture):
|
||||||
mock_hf_config = MagicMock()
|
mock_hf_config = MagicMock()
|
||||||
@@ -77,7 +129,13 @@ def mock_dist_env(mocker: MockerFixture):
|
|||||||
mock_moe_comm_method = MagicMock()
|
mock_moe_comm_method = MagicMock()
|
||||||
|
|
||||||
def mock_prepare(hidden_states, router_logits, **kwargs):
|
def mock_prepare(hidden_states, router_logits, **kwargs):
|
||||||
return hidden_states, router_logits
|
return MoEPrepareOutput(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
mc2_mask=kwargs.get("mc2_mask"),
|
||||||
|
padded_hidden_states_shape=None,
|
||||||
|
pertoken_scale=None,
|
||||||
|
)
|
||||||
|
|
||||||
mock_moe_comm_method.prepare.side_effect = mock_prepare
|
mock_moe_comm_method.prepare.side_effect = mock_prepare
|
||||||
|
|
||||||
@@ -204,18 +262,18 @@ def moe_method(mock_dist_env):
|
|||||||
|
|
||||||
class Device(TypedDict):
|
class Device(TypedDict):
|
||||||
device_id: int
|
device_id: int
|
||||||
device_expert: List[int]
|
device_expert: list[int]
|
||||||
|
|
||||||
|
|
||||||
class Layer(TypedDict):
|
class Layer(TypedDict):
|
||||||
layer_id: int
|
layer_id: int
|
||||||
device_count: int
|
device_count: int
|
||||||
device_list: List[Device]
|
device_list: list[Device]
|
||||||
|
|
||||||
|
|
||||||
class MockData(TypedDict):
|
class MockData(TypedDict):
|
||||||
moe_layer_count: int
|
moe_layer_count: int
|
||||||
layer_list: List[Layer]
|
layer_list: list[Layer]
|
||||||
|
|
||||||
|
|
||||||
class MockQuantMethod(nn.Module):
|
class MockQuantMethod(nn.Module):
|
||||||
@@ -338,18 +396,15 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
|
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
|
||||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||||
|
|
||||||
result = unified_apply_mlp(hidden_states=hidden_states,
|
result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
|
||||||
w1=w1,
|
hidden_states=hidden_states,
|
||||||
w1_scale=w1_scale,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
w2_scale=w2_scale,
|
group_list=group_list,
|
||||||
group_list=group_list,
|
with_quant=True,
|
||||||
dynamic_scale=None,
|
w1_scale=w1_scale,
|
||||||
group_list_type=1,
|
w2_scale=w2_scale,
|
||||||
w1_scale_bias=None,
|
))
|
||||||
w2_scale_bias=None,
|
|
||||||
topk_scales=None,
|
|
||||||
with_quant=True)
|
|
||||||
|
|
||||||
mock_get_forward_context.assert_called()
|
mock_get_forward_context.assert_called()
|
||||||
|
|
||||||
@@ -383,18 +438,14 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||||
topk_scales = torch.randn(10, 1, dtype=torch.float16)
|
topk_scales = torch.randn(10, 1, dtype=torch.float16)
|
||||||
|
|
||||||
result = unified_apply_mlp(hidden_states=hidden_states,
|
result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
|
||||||
w1=w1,
|
hidden_states=hidden_states,
|
||||||
w1_scale=None,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
w2_scale=None,
|
group_list=group_list,
|
||||||
group_list=group_list,
|
with_quant=False,
|
||||||
dynamic_scale=None,
|
topk_scales=topk_scales,
|
||||||
group_list_type=1,
|
))
|
||||||
w1_scale_bias=None,
|
|
||||||
w2_scale_bias=None,
|
|
||||||
topk_scales=topk_scales,
|
|
||||||
with_quant=False)
|
|
||||||
|
|
||||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||||
mock_npu_swiglu.assert_called_once()
|
mock_npu_swiglu.assert_called_once()
|
||||||
@@ -445,18 +496,18 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||||
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
|
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
|
||||||
|
|
||||||
result = unified_apply_mlp(hidden_states=hidden_states,
|
result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
|
||||||
w1=w1,
|
hidden_states=hidden_states,
|
||||||
w1_scale=w1_scale,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
w2_scale=w2_scale,
|
group_list=group_list,
|
||||||
group_list=group_list,
|
with_quant=True,
|
||||||
dynamic_scale=provided_dynamic_scale,
|
dynamic_scale=provided_dynamic_scale,
|
||||||
group_list_type=1,
|
w1_scale=w1_scale,
|
||||||
w1_scale_bias=w1_scale_bias,
|
w2_scale=w2_scale,
|
||||||
w2_scale_bias=w2_scale_bias,
|
w1_scale_bias=w1_scale_bias,
|
||||||
topk_scales=None,
|
w2_scale_bias=w2_scale_bias,
|
||||||
with_quant=True)
|
))
|
||||||
|
|
||||||
mock_get_forward_context.assert_called()
|
mock_get_forward_context.assert_called()
|
||||||
|
|
||||||
@@ -490,18 +541,14 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||||
topk_scales = torch.randn(10, 1, dtype=torch.float16)
|
topk_scales = torch.randn(10, 1, dtype=torch.float16)
|
||||||
|
|
||||||
result = unified_apply_mlp(hidden_states=hidden_states,
|
result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
|
||||||
w1=w1,
|
hidden_states=hidden_states,
|
||||||
w1_scale=None,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
w2_scale=None,
|
group_list=group_list,
|
||||||
group_list=group_list,
|
with_quant=False,
|
||||||
dynamic_scale=None,
|
topk_scales=topk_scales,
|
||||||
group_list_type=1,
|
))
|
||||||
w1_scale_bias=None,
|
|
||||||
w2_scale_bias=None,
|
|
||||||
topk_scales=topk_scales,
|
|
||||||
with_quant=False)
|
|
||||||
|
|
||||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||||
mock_npu_swiglu.assert_called_once()
|
mock_npu_swiglu.assert_called_once()
|
||||||
@@ -556,19 +603,19 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||||
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
|
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
|
||||||
|
|
||||||
result = unified_apply_mlp(hidden_states=hidden_states,
|
result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
|
||||||
w1=w1,
|
hidden_states=hidden_states,
|
||||||
w1_scale=w1_scale,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
w2_scale=w2_scale,
|
group_list=group_list,
|
||||||
group_list=group_list,
|
with_quant=True,
|
||||||
dynamic_scale=provided_dynamic_scale,
|
dynamic_scale=provided_dynamic_scale,
|
||||||
group_list_type=1,
|
w1_scale=w1_scale,
|
||||||
w1_scale_bias=w1_scale_bias,
|
w2_scale=w2_scale,
|
||||||
w2_scale_bias=w2_scale_bias,
|
w1_scale_bias=w1_scale_bias,
|
||||||
topk_scales=None,
|
w2_scale_bias=w2_scale_bias,
|
||||||
with_quant=True,
|
fusion=True,
|
||||||
fusion=True)
|
))
|
||||||
|
|
||||||
mock_get_forward_context.assert_called()
|
mock_get_forward_context.assert_called()
|
||||||
mock_npu_grouped_matmul.assert_called_once()
|
mock_npu_grouped_matmul.assert_called_once()
|
||||||
|
|||||||
@@ -4,12 +4,21 @@ import torch
|
|||||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
from vllm_ascend.ops.fused_moe.moe_comm_method import (
|
||||||
AlltoAllCommImpl,
|
AllGatherCommImpl,
|
||||||
MC2CommImpl)
|
AlltoAllCommImpl,
|
||||||
|
MC2CommImpl,
|
||||||
|
)
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||||
|
MoEAllGatherCombineMetadata,
|
||||||
|
MoEFusedExpertsInput,
|
||||||
|
MoEPrepareOutput,
|
||||||
|
MoEQuantParams,
|
||||||
|
MoERoutingParams,
|
||||||
|
MoEWeights,
|
||||||
|
)
|
||||||
|
from vllm_ascend.ops.fused_moe.token_dispatcher import MoETokenDispatchOutput
|
||||||
from vllm_ascend.quantization.methods.base import QuantType
|
from vllm_ascend.quantization.methods.base import QuantType
|
||||||
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
|
|
||||||
TokenDispatchResult)
|
|
||||||
|
|
||||||
|
|
||||||
class TestMoECommMethod(TestBase):
|
class TestMoECommMethod(TestBase):
|
||||||
@@ -45,8 +54,11 @@ class TestMoECommMethod(TestBase):
|
|||||||
|
|
||||||
# Mock prepare finalize
|
# Mock prepare finalize
|
||||||
mock_pf_instance = MagicMock()
|
mock_pf_instance = MagicMock()
|
||||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
mock_pf_instance.prepare.return_value = MoEPrepareOutput(
|
||||||
torch.randn(4, 2), None, None)
|
hidden_states=torch.randn(4, 8),
|
||||||
|
router_logits=torch.randn(4, 2),
|
||||||
|
mc2_mask=None,
|
||||||
|
padded_hidden_states_shape=None)
|
||||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||||
mock_prepare_finalize.return_value = mock_pf_instance
|
mock_prepare_finalize.return_value = mock_pf_instance
|
||||||
|
|
||||||
@@ -60,8 +72,9 @@ class TestMoECommMethod(TestBase):
|
|||||||
# Test prepare method
|
# Test prepare method
|
||||||
hidden_states = torch.randn(3, 8)
|
hidden_states = torch.randn(3, 8)
|
||||||
router_logits = torch.randn(3, 2)
|
router_logits = torch.randn(3, 2)
|
||||||
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
|
prepare_output = comm_impl.prepare(hidden_states, router_logits)
|
||||||
hidden_states, router_logits)
|
h_out = prepare_output.hidden_states
|
||||||
|
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
|
||||||
|
|
||||||
# Verify prepare was called with correct arguments
|
# Verify prepare was called with correct arguments
|
||||||
mock_pf_instance.prepare.assert_called_once_with(
|
mock_pf_instance.prepare.assert_called_once_with(
|
||||||
@@ -70,7 +83,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
# Test finalize method
|
# Test finalize method
|
||||||
comm_impl.finalize(h_out,
|
comm_impl.finalize(h_out,
|
||||||
reduce_results=True,
|
reduce_results=True,
|
||||||
context_metadata=context_metadata)
|
padded_hidden_states_shape=padded_hidden_states_shape)
|
||||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||||
|
|
||||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||||
@@ -86,10 +99,11 @@ class TestMoECommMethod(TestBase):
|
|||||||
|
|
||||||
# Mock prepare finalize
|
# Mock prepare finalize
|
||||||
mock_pf_instance = MagicMock()
|
mock_pf_instance = MagicMock()
|
||||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
mock_pf_instance.prepare.return_value = MoEPrepareOutput(
|
||||||
torch.randn(4, 2),
|
hidden_states=torch.randn(4, 8),
|
||||||
torch.tensor([1, 0, 1,
|
router_logits=torch.randn(4, 2),
|
||||||
0]), None)
|
mc2_mask=torch.tensor([1, 0, 1, 0]),
|
||||||
|
padded_hidden_states_shape=None)
|
||||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||||
mock_prepare_finalize.return_value = mock_pf_instance
|
mock_prepare_finalize.return_value = mock_pf_instance
|
||||||
|
|
||||||
@@ -103,8 +117,9 @@ class TestMoECommMethod(TestBase):
|
|||||||
# Test prepare method
|
# Test prepare method
|
||||||
hidden_states = torch.randn(3, 8)
|
hidden_states = torch.randn(3, 8)
|
||||||
router_logits = torch.randn(3, 2)
|
router_logits = torch.randn(3, 2)
|
||||||
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
|
prepare_output = comm_impl.prepare(hidden_states, router_logits)
|
||||||
hidden_states, router_logits)
|
h_out = prepare_output.hidden_states
|
||||||
|
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
|
||||||
|
|
||||||
# Verify prepare was called with correct arguments
|
# Verify prepare was called with correct arguments
|
||||||
mock_pf_instance.prepare.assert_called_once_with(
|
mock_pf_instance.prepare.assert_called_once_with(
|
||||||
@@ -113,7 +128,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
# Test finalize method
|
# Test finalize method
|
||||||
comm_impl.finalize(h_out,
|
comm_impl.finalize(h_out,
|
||||||
reduce_results=True,
|
reduce_results=True,
|
||||||
context_metadata=context_metadata)
|
padded_hidden_states_shape=padded_hidden_states_shape)
|
||||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||||
|
|
||||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||||
@@ -133,8 +148,11 @@ class TestMoECommMethod(TestBase):
|
|||||||
|
|
||||||
# Mock prepare finalize
|
# Mock prepare finalize
|
||||||
mock_pf_instance = MagicMock()
|
mock_pf_instance = MagicMock()
|
||||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
mock_pf_instance.prepare.return_value = MoEPrepareOutput(
|
||||||
torch.randn(4, 2), None, None)
|
hidden_states=torch.randn(4, 8),
|
||||||
|
router_logits=torch.randn(4, 2),
|
||||||
|
mc2_mask=None,
|
||||||
|
padded_hidden_states_shape=None)
|
||||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||||
mock_prepare_finalize.return_value = mock_pf_instance
|
mock_prepare_finalize.return_value = mock_pf_instance
|
||||||
|
|
||||||
@@ -148,8 +166,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
# Test prepare method
|
# Test prepare method
|
||||||
hidden_states = torch.randn(3, 8)
|
hidden_states = torch.randn(3, 8)
|
||||||
router_logits = torch.randn(3, 2)
|
router_logits = torch.randn(3, 2)
|
||||||
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
|
_ = comm_impl.prepare(hidden_states, router_logits)
|
||||||
hidden_states, router_logits)
|
|
||||||
|
|
||||||
# Verify prepare was called with correct arguments
|
# Verify prepare was called with correct arguments
|
||||||
mock_pf_instance.prepare.assert_called_once_with(
|
mock_pf_instance.prepare.assert_called_once_with(
|
||||||
@@ -174,19 +191,27 @@ class TestMoECommMethod(TestBase):
|
|||||||
|
|
||||||
# Mock prepare finalize
|
# Mock prepare finalize
|
||||||
mock_pf_instance = MagicMock()
|
mock_pf_instance = MagicMock()
|
||||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
mock_pf_instance.prepare.return_value = MoEPrepareOutput(
|
||||||
torch.randn(4, 2), None)
|
hidden_states=torch.randn(4, 8),
|
||||||
|
router_logits=torch.randn(4, 2),
|
||||||
|
mc2_mask=None,
|
||||||
|
padded_hidden_states_shape=None)
|
||||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||||
mock_prepare_finalize.return_value = mock_pf_instance
|
mock_prepare_finalize.return_value = mock_pf_instance
|
||||||
|
|
||||||
# Mock token dispatcher
|
# Mock token dispatcher
|
||||||
mock_td_instance = MagicMock()
|
mock_td_instance = MagicMock()
|
||||||
mock_td_instance.token_dispatch.return_value = TokenDispatchResult(
|
dispatch_topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2], [0.6, 0.4]])
|
||||||
|
mock_td_instance.token_dispatch.return_value = MoETokenDispatchOutput(
|
||||||
hidden_states=torch.randn(6, 8),
|
hidden_states=torch.randn(6, 8),
|
||||||
group_list=torch.tensor([2, 2, 2]),
|
group_list=torch.tensor([2, 2, 2]),
|
||||||
group_list_type=1)
|
group_list_type=1,
|
||||||
mock_td_instance.token_combine.return_value = TokenCombineResult(
|
combine_metadata=MoEAllGatherCombineMetadata(
|
||||||
routed_out=torch.randn(4, 8))
|
topk_weights=dispatch_topk_weights,
|
||||||
|
expanded_row_idx=torch.arange(8, dtype=torch.int32),
|
||||||
|
restore_shape=torch.Size([4, 8]),
|
||||||
|
))
|
||||||
|
mock_td_instance.token_combine.return_value = torch.randn(4, 8)
|
||||||
mock_token_dispatcher.return_value = mock_td_instance
|
mock_token_dispatcher.return_value = mock_td_instance
|
||||||
|
|
||||||
# Mock unified_apply_mlp
|
# Mock unified_apply_mlp
|
||||||
@@ -199,8 +224,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
hidden_states = torch.randn(4, 8).contiguous()
|
hidden_states = torch.randn(4, 8).contiguous()
|
||||||
w1 = torch.randn(16, 8).contiguous()
|
w1 = torch.randn(16, 8).contiguous()
|
||||||
w2 = torch.randn(16, 8).contiguous()
|
w2 = torch.randn(16, 8).contiguous()
|
||||||
topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
|
topk_weights = dispatch_topk_weights
|
||||||
[0.6, 0.4]])
|
|
||||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
|
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
|
||||||
|
|
||||||
# Make sure tensors are contiguous and have correct strides
|
# Make sure tensors are contiguous and have correct strides
|
||||||
@@ -208,12 +232,25 @@ class TestMoECommMethod(TestBase):
|
|||||||
w1 = w1.contiguous()
|
w1 = w1.contiguous()
|
||||||
w2 = w2.contiguous()
|
w2 = w2.contiguous()
|
||||||
|
|
||||||
result = comm_impl.fused_experts(hidden_states=hidden_states,
|
result = comm_impl.fused_experts(fused_experts_input=MoEFusedExpertsInput(
|
||||||
w1=[w1],
|
hidden_states=hidden_states,
|
||||||
w2=[w2],
|
topk_weights=topk_weights,
|
||||||
topk_weights=topk_weights,
|
topk_ids=topk_ids,
|
||||||
topk_ids=topk_ids,
|
weights=MoEWeights(
|
||||||
activation="silu")
|
w1=[w1],
|
||||||
|
w2=[w2],
|
||||||
|
),
|
||||||
|
routing=MoERoutingParams(
|
||||||
|
expert_map=None,
|
||||||
|
global_redundant_expert_num=0,
|
||||||
|
mc2_mask=None,
|
||||||
|
apply_router_weight_on_input=False,
|
||||||
|
),
|
||||||
|
activation="silu",
|
||||||
|
need_trans=False,
|
||||||
|
dynamic_eplb=False,
|
||||||
|
quant=MoEQuantParams(),
|
||||||
|
))
|
||||||
|
|
||||||
# Verify result shape
|
# Verify result shape
|
||||||
self.assertEqual(result.routed_out.shape, (4, 8))
|
self.assertEqual(result.routed_out.shape, (4, 8))
|
||||||
@@ -223,6 +260,12 @@ class TestMoECommMethod(TestBase):
|
|||||||
|
|
||||||
# Verify unified_apply_mlp was called
|
# Verify unified_apply_mlp was called
|
||||||
mock_unified_apply_mlp.assert_called_once()
|
mock_unified_apply_mlp.assert_called_once()
|
||||||
|
mlp_compute_input = mock_unified_apply_mlp.call_args.kwargs["mlp_compute_input"]
|
||||||
|
self.assertFalse(mlp_compute_input.fusion)
|
||||||
|
self.assertFalse(mlp_compute_input.quant.is_mxfp)
|
||||||
|
|
||||||
# Verify token_combine was called
|
# Verify token_combine was called
|
||||||
mock_td_instance.token_combine.assert_called_once()
|
mock_td_instance.token_combine.assert_called_once_with(
|
||||||
|
hidden_states=mock_unified_apply_mlp.return_value,
|
||||||
|
combine_metadata=mock_td_instance.token_dispatch.return_value.combine_metadata,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from typing import ClassVar
|
from typing import ClassVar
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list
|
from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||||
|
MoEMlpComputeInput,
|
||||||
|
MoEQuantParams,
|
||||||
|
MoEWeights,
|
||||||
|
)
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
|
||||||
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
|
|
||||||
|
|
||||||
class TestCumsumGroupList(unittest.TestCase):
|
class TestCumsumGroupList(unittest.TestCase):
|
||||||
@@ -14,7 +22,7 @@ class TestCumsumGroupList(unittest.TestCase):
|
|||||||
cls.glist_dict = {
|
cls.glist_dict = {
|
||||||
0: torch.tensor([0, 2, 3, 3]),
|
0: torch.tensor([0, 2, 3, 3]),
|
||||||
1: torch.tensor([0, 2, 1, 0]),
|
1: torch.tensor([0, 2, 1, 0]),
|
||||||
2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]])
|
2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]]),
|
||||||
}
|
}
|
||||||
|
|
||||||
support_combine = [(0, 0), (1, 0), (0, 1)]
|
support_combine = [(0, 0), (1, 0), (0, 1)]
|
||||||
@@ -23,29 +31,101 @@ class TestCumsumGroupList(unittest.TestCase):
|
|||||||
def test_cumsum_group_list_supported_conversion(self):
|
def test_cumsum_group_list_supported_conversion(self):
|
||||||
for src_list_type, dst_list_type in self.support_combine:
|
for src_list_type, dst_list_type in self.support_combine:
|
||||||
with self.subTest(src=src_list_type, dst=dst_list_type):
|
with self.subTest(src=src_list_type, dst=dst_list_type):
|
||||||
result = cumsum_group_list(self.glist_dict[src_list_type],
|
result = cumsum_group_list(self.glist_dict[src_list_type], src_list_type, dst_list_type, expert_num=4)
|
||||||
src_list_type,
|
self.assertTrue(torch.equal(result, self.glist_dict[dst_list_type]))
|
||||||
dst_list_type,
|
|
||||||
expert_num=4)
|
|
||||||
self.assertTrue(
|
|
||||||
torch.equal(result, self.glist_dict[dst_list_type]))
|
|
||||||
|
|
||||||
def test_cumsum_group_list_invalid_type_valueerror(self):
|
def test_cumsum_group_list_invalid_type_valueerror(self):
|
||||||
with self.assertRaises(ValueError) as excinfo:
|
with self.assertRaises(ValueError) as excinfo:
|
||||||
cumsum_group_list(self.glist_dict[0], 4, 0)
|
cumsum_group_list(self.glist_dict[0], 4, 0)
|
||||||
self.assertIn("group_list_type should be in [0, 1, 2], but received",
|
self.assertIn("group_list_type should be in [0, 1, 2], but received", str(excinfo.exception))
|
||||||
str(excinfo.exception))
|
|
||||||
|
|
||||||
def test_cumsum_group_list_unsupported_conversion_notimplementederror(
|
def test_cumsum_group_list_unsupported_conversion_notimplementederror(self):
|
||||||
self):
|
|
||||||
for src_list_type, dst_list_type in self.unsupported_combine:
|
for src_list_type, dst_list_type in self.unsupported_combine:
|
||||||
with self.subTest(src=src_list_type, dst=dst_list_type):
|
with self.subTest(src=src_list_type, dst=dst_list_type):
|
||||||
with self.assertRaises(NotImplementedError) as excinfo:
|
with self.assertRaises(NotImplementedError) as excinfo:
|
||||||
cumsum_group_list(self.glist_dict[0], src_list_type,
|
cumsum_group_list(self.glist_dict[0], src_list_type, dst_list_type)
|
||||||
dst_list_type)
|
self.assertIn("This feature is under development.", str(excinfo.exception))
|
||||||
self.assertIn("This feature is under development.",
|
|
||||||
str(excinfo.exception))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
class TestUnifiedApplyMlpRequest(unittest.TestCase):
|
||||||
|
def test_request_unquant_path(self):
|
||||||
|
hidden_states = torch.randn(2, 8)
|
||||||
|
expected = torch.randn(2, 8)
|
||||||
|
mlp_compute_input = MoEMlpComputeInput(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
group_list=torch.tensor([2, 2], dtype=torch.int64),
|
||||||
|
group_list_type=1,
|
||||||
|
dynamic_scale=None,
|
||||||
|
topk_scales=None,
|
||||||
|
weights=MoEWeights(
|
||||||
|
w1=torch.randn(1, 16, 8),
|
||||||
|
w2=torch.randn(1, 8, 8),
|
||||||
|
w1_bias=torch.randn(1, 16),
|
||||||
|
w2_bias=torch.randn(1, 8),
|
||||||
|
),
|
||||||
|
quant=MoEQuantParams(quant_type=QuantType.NONE),
|
||||||
|
fusion=False,
|
||||||
|
activation="silu",
|
||||||
|
need_trans=False,
|
||||||
|
dynamic_eplb=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp", return_value=expected) as mock_unquant,
|
||||||
|
patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp") as mock_quant,
|
||||||
|
):
|
||||||
|
output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
|
||||||
|
|
||||||
|
self.assertTrue(output is expected)
|
||||||
|
mock_unquant.assert_called_once()
|
||||||
|
self.assertEqual(mock_unquant.call_args.kwargs["activation"], "silu")
|
||||||
|
self.assertFalse(mock_unquant.call_args.kwargs["need_trans"])
|
||||||
|
mock_quant.assert_not_called()
|
||||||
|
|
||||||
|
def test_request_quant_path(self):
|
||||||
|
hidden_states = torch.randn(2, 8)
|
||||||
|
expected = torch.randn(2, 8)
|
||||||
|
mlp_compute_input = MoEMlpComputeInput(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
group_list=torch.tensor([2, 2], dtype=torch.int64),
|
||||||
|
group_list_type=1,
|
||||||
|
dynamic_scale=torch.randn(2, 1),
|
||||||
|
topk_scales=None,
|
||||||
|
weights=MoEWeights(
|
||||||
|
w1=torch.randn(1, 16, 8),
|
||||||
|
w2=torch.randn(1, 8, 8),
|
||||||
|
w1_scale=[torch.randn(1)],
|
||||||
|
w2_scale=[torch.randn(1)],
|
||||||
|
),
|
||||||
|
quant=MoEQuantParams(
|
||||||
|
quant_type=QuantType.MXFP8,
|
||||||
|
mxfp=MoEMxfpParams(
|
||||||
|
act_quant_type=torch.float8_e4m3fn,
|
||||||
|
weight_quant_type=torch.float8_e4m3fn,
|
||||||
|
use_bf16=False,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
fusion=True,
|
||||||
|
activation="silu",
|
||||||
|
need_trans=False,
|
||||||
|
dynamic_eplb=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp", return_value=expected) as mock_quant,
|
||||||
|
patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp") as mock_unquant,
|
||||||
|
):
|
||||||
|
output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
|
||||||
|
|
||||||
|
self.assertTrue(output is expected)
|
||||||
|
mock_quant.assert_called_once()
|
||||||
|
quant_kwargs = mock_quant.call_args.kwargs
|
||||||
|
self.assertTrue(quant_kwargs["use_mxfp_quant"])
|
||||||
|
self.assertTrue(quant_kwargs["fusion"])
|
||||||
|
self.assertTrue(quant_kwargs["dynamic_eplb"])
|
||||||
|
self.assertFalse(quant_kwargs["use_bf16"])
|
||||||
|
mock_unquant.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
unittest.main(verbosity=2)
|
unittest.main(verbosity=2)
|
||||||
|
|||||||
240
tests/ut/ops/test_moe_runtime_args.py
Normal file
240
tests/ut/ops/test_moe_runtime_args.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm_ascend.ops.fused_moe.moe_runtime_args as runtime_args
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||||
|
MoEAllGatherCombineMetadata,
|
||||||
|
MoETokenDispatchOutput,
|
||||||
|
MoEWeights,
|
||||||
|
build_fused_experts_input,
|
||||||
|
build_mlp_compute_input,
|
||||||
|
build_token_dispatch_input,
|
||||||
|
)
|
||||||
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoERuntimeArgs(unittest.TestCase):
|
||||||
|
def test_runtime_args_facade_exports_public_contracts_and_builders(self):
|
||||||
|
expected_symbols = [
|
||||||
|
"MoEAllGatherCombineMetadata",
|
||||||
|
"MoEAllToAllCombineMetadata",
|
||||||
|
"MoEFusedExpertsInput",
|
||||||
|
"MoEMC2CombineMetadata",
|
||||||
|
"MoEMlpComputeInput",
|
||||||
|
"MoEPrepareOutput",
|
||||||
|
"MoEQuantParams",
|
||||||
|
"MoERoutingParams",
|
||||||
|
"MoETokenDispatchInput",
|
||||||
|
"MoETokenDispatchOutput",
|
||||||
|
"MoEWeights",
|
||||||
|
"TMoECombineMetadata",
|
||||||
|
"build_fused_experts_input",
|
||||||
|
"build_mlp_compute_input",
|
||||||
|
"build_token_dispatch_input",
|
||||||
|
]
|
||||||
|
|
||||||
|
for symbol in expected_symbols:
|
||||||
|
with self.subTest(symbol=symbol):
|
||||||
|
self.assertTrue(hasattr(runtime_args, symbol))
|
||||||
|
self.assertFalse(hasattr(runtime_args, "MoEMxfpParams"))
|
||||||
|
|
||||||
|
def test_build_fused_experts_input_preserves_runtime_semantics(self):
|
||||||
|
for quant_type in (
|
||||||
|
QuantType.NONE,
|
||||||
|
QuantType.W4A16,
|
||||||
|
QuantType.W4A8,
|
||||||
|
QuantType.W8A8,
|
||||||
|
QuantType.MXFP8,
|
||||||
|
):
|
||||||
|
with self.subTest(quant_type=quant_type):
|
||||||
|
hidden_states = torch.randn(4, 8)
|
||||||
|
topk_weights = torch.randn(4, 2)
|
||||||
|
topk_ids = torch.randint(0, 4, (4, 2), dtype=torch.int32)
|
||||||
|
fused_experts_input = build_fused_experts_input(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
w1=torch.randn(2, 8, 16),
|
||||||
|
w2=torch.randn(2, 16, 8),
|
||||||
|
quant_type=quant_type,
|
||||||
|
dynamic_eplb=True,
|
||||||
|
expert_map=torch.tensor([0, 1, 2, 3], dtype=torch.int32),
|
||||||
|
global_redundant_expert_num=2,
|
||||||
|
mc2_mask=torch.tensor([True, False, True, False]),
|
||||||
|
apply_router_weight_on_input=True,
|
||||||
|
log2phy=torch.tensor([3, 2, 1, 0], dtype=torch.int32),
|
||||||
|
pertoken_scale=torch.randn(4),
|
||||||
|
activation="gelu",
|
||||||
|
mxfp_act_quant_type=torch.float8_e4m3fn if quant_type == QuantType.MXFP8 else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIs(fused_experts_input.hidden_states, hidden_states)
|
||||||
|
self.assertIs(fused_experts_input.topk_weights, topk_weights)
|
||||||
|
self.assertIs(fused_experts_input.topk_ids, topk_ids)
|
||||||
|
self.assertTrue(fused_experts_input.dynamic_eplb)
|
||||||
|
self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
|
||||||
|
self.assertEqual(fused_experts_input.routing.global_redundant_expert_num, 2)
|
||||||
|
self.assertEqual(fused_experts_input.activation, "gelu")
|
||||||
|
self.assertEqual(fused_experts_input.quant.quant_type, quant_type)
|
||||||
|
|
||||||
|
def test_build_fused_experts_input_merges_dense_and_quant_weights(self):
|
||||||
|
w1 = torch.randn(2, 8, 16)
|
||||||
|
w2 = torch.randn(2, 16, 8)
|
||||||
|
w1_scale = [torch.randn(1)]
|
||||||
|
w2_scale = [torch.randn(1)]
|
||||||
|
w1_scale_bias = torch.randn(1)
|
||||||
|
w2_scale_bias = torch.randn(1)
|
||||||
|
w1_offset = torch.randn(1)
|
||||||
|
w2_offset = torch.randn(1)
|
||||||
|
|
||||||
|
fused_experts_input = build_fused_experts_input(
|
||||||
|
hidden_states=torch.randn(4, 8),
|
||||||
|
topk_weights=torch.randn(4, 2),
|
||||||
|
topk_ids=torch.randint(0, 4, (4, 2), dtype=torch.int32),
|
||||||
|
w1=w1,
|
||||||
|
w2=w2,
|
||||||
|
quant_type=QuantType.W8A8,
|
||||||
|
dynamic_eplb=False,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
w1_scale_bias=w1_scale_bias,
|
||||||
|
w2_scale_bias=w2_scale_bias,
|
||||||
|
w1_offset=w1_offset,
|
||||||
|
w2_offset=w2_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsInstance(fused_experts_input.weights, MoEWeights)
|
||||||
|
self.assertIs(fused_experts_input.weights.w1, w1)
|
||||||
|
self.assertIs(fused_experts_input.weights.w2, w2)
|
||||||
|
self.assertIs(fused_experts_input.weights.w1_scale, w1_scale)
|
||||||
|
self.assertIs(fused_experts_input.weights.w2_scale, w2_scale)
|
||||||
|
self.assertIs(fused_experts_input.weights.w1_scale_bias, w1_scale_bias)
|
||||||
|
self.assertIs(fused_experts_input.weights.w2_scale_bias, w2_scale_bias)
|
||||||
|
self.assertIs(fused_experts_input.weights.w1_offset, w1_offset)
|
||||||
|
self.assertIs(fused_experts_input.weights.w2_offset, w2_offset)
|
||||||
|
|
||||||
|
def test_build_token_dispatch_input_supports_remapped_topk_ids(self):
|
||||||
|
fused_experts_input = build_fused_experts_input(
|
||||||
|
hidden_states=torch.randn(2, 4),
|
||||||
|
topk_weights=torch.randn(2, 1),
|
||||||
|
topk_ids=torch.tensor([[0], [1]], dtype=torch.int32),
|
||||||
|
w1=torch.randn(1, 4, 8),
|
||||||
|
w2=torch.randn(1, 8, 4),
|
||||||
|
quant_type=QuantType.NONE,
|
||||||
|
dynamic_eplb=False,
|
||||||
|
)
|
||||||
|
routed_topk_ids = torch.tensor([[3], [2]], dtype=torch.int32)
|
||||||
|
|
||||||
|
token_dispatch_input = build_token_dispatch_input(
|
||||||
|
fused_experts_input=fused_experts_input,
|
||||||
|
topk_ids=routed_topk_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIs(token_dispatch_input.hidden_states, fused_experts_input.hidden_states)
|
||||||
|
self.assertIs(token_dispatch_input.topk_weights, fused_experts_input.topk_weights)
|
||||||
|
self.assertIs(token_dispatch_input.routing, fused_experts_input.routing)
|
||||||
|
self.assertIs(token_dispatch_input.quant, fused_experts_input.quant)
|
||||||
|
self.assertIs(token_dispatch_input.topk_ids, routed_topk_ids)
|
||||||
|
|
||||||
|
def test_build_fused_experts_input_requires_primitive_mxfp_params_for_mxfp_quant(self):
|
||||||
|
with self.assertRaisesRegex(ValueError, "primitive MXFP params are required"):
|
||||||
|
build_fused_experts_input(
|
||||||
|
hidden_states=torch.randn(2, 8),
|
||||||
|
topk_weights=torch.randn(2, 2),
|
||||||
|
topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
|
||||||
|
w1=torch.randn(2, 8, 16),
|
||||||
|
w2=torch.randn(2, 16, 8),
|
||||||
|
quant_type=QuantType.MXFP8,
|
||||||
|
dynamic_eplb=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_build_mlp_compute_input_derives_fusion_and_preserves_mxfp_params(self):
|
||||||
|
fused_experts_input = build_fused_experts_input(
|
||||||
|
hidden_states=torch.randn(2, 8, dtype=torch.bfloat16),
|
||||||
|
topk_weights=torch.randn(2, 2),
|
||||||
|
topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
|
||||||
|
w1=torch.randn(2, 8, 16),
|
||||||
|
w2=torch.randn(2, 16, 8),
|
||||||
|
quant_type=QuantType.MXFP8,
|
||||||
|
dynamic_eplb=False,
|
||||||
|
mxfp_act_quant_type=torch.float8_e4m3fn,
|
||||||
|
mxfp_weight_quant_type=torch.float8_e4m3fn,
|
||||||
|
mxfp_scale_dtype=torch.float32,
|
||||||
|
mxfp_per_token_scale_dtype=torch.float16,
|
||||||
|
mxfp_use_bf16=False,
|
||||||
|
w1_scale=[torch.randn(1)],
|
||||||
|
w2_scale=[torch.randn(1)],
|
||||||
|
)
|
||||||
|
token_dispatch_output = MoETokenDispatchOutput(
|
||||||
|
hidden_states=torch.randn(4, 8, dtype=torch.bfloat16),
|
||||||
|
group_list=torch.tensor([2, 2], dtype=torch.int64),
|
||||||
|
group_list_type=1,
|
||||||
|
dynamic_scale=torch.randn(4, 1),
|
||||||
|
combine_metadata=MoEAllGatherCombineMetadata(
|
||||||
|
topk_weights=fused_experts_input.topk_weights,
|
||||||
|
expanded_row_idx=torch.arange(4, dtype=torch.int32),
|
||||||
|
restore_shape=torch.Size([2, 8]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
mlp_compute_input = build_mlp_compute_input(
|
||||||
|
fused_experts_input=fused_experts_input,
|
||||||
|
token_dispatch_output=token_dispatch_output,
|
||||||
|
use_fusion_ops=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIs(mlp_compute_input.hidden_states, token_dispatch_output.hidden_states)
|
||||||
|
self.assertIs(mlp_compute_input.weights, fused_experts_input.weights)
|
||||||
|
self.assertIs(mlp_compute_input.weights.w1_scale, fused_experts_input.weights.w1_scale)
|
||||||
|
self.assertIs(mlp_compute_input.weights.w2_scale, fused_experts_input.weights.w2_scale)
|
||||||
|
self.assertTrue(mlp_compute_input.fusion)
|
||||||
|
self.assertTrue(mlp_compute_input.quant.is_mxfp)
|
||||||
|
assert mlp_compute_input.quant.mxfp is not None
|
||||||
|
self.assertEqual(mlp_compute_input.quant.mxfp.scale_dtype, torch.float32)
|
||||||
|
self.assertEqual(mlp_compute_input.quant.mxfp.per_token_scale_dtype, torch.float16)
|
||||||
|
self.assertFalse(mlp_compute_input.quant.mxfp.use_bf16)
|
||||||
|
|
||||||
|
def test_build_fused_experts_input_constructs_internal_mxfp_leaf_from_primitives(self):
|
||||||
|
fused_experts_input = build_fused_experts_input(
|
||||||
|
hidden_states=torch.randn(2, 8, dtype=torch.bfloat16),
|
||||||
|
topk_weights=torch.randn(2, 2),
|
||||||
|
topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
|
||||||
|
w1=torch.randn(2, 8, 16),
|
||||||
|
w2=torch.randn(2, 16, 8),
|
||||||
|
quant_type=QuantType.MXFP8,
|
||||||
|
dynamic_eplb=False,
|
||||||
|
mxfp_act_quant_type=torch.float8_e4m3fn,
|
||||||
|
mxfp_weight_quant_type=torch.float8_e4m3fn,
|
||||||
|
mxfp_scale_dtype=torch.float32,
|
||||||
|
mxfp_per_token_scale_dtype=torch.float16,
|
||||||
|
mxfp_use_bf16=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(fused_experts_input.quant.is_mxfp)
|
||||||
|
assert fused_experts_input.quant.mxfp is not None
|
||||||
|
self.assertEqual(fused_experts_input.quant.mxfp.act_quant_type, torch.float8_e4m3fn)
|
||||||
|
self.assertEqual(fused_experts_input.quant.mxfp.weight_quant_type, torch.float8_e4m3fn)
|
||||||
|
self.assertEqual(fused_experts_input.quant.mxfp.scale_dtype, torch.float32)
|
||||||
|
self.assertEqual(fused_experts_input.quant.mxfp.per_token_scale_dtype, torch.float16)
|
||||||
|
self.assertFalse(fused_experts_input.quant.mxfp.use_bf16)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main(verbosity=2)
|
||||||
@@ -45,18 +45,22 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
|||||||
hidden_states = torch.randn(3, 8)
|
hidden_states = torch.randn(3, 8)
|
||||||
router_logits = torch.randn(3, 2)
|
router_logits = torch.randn(3, 2)
|
||||||
|
|
||||||
h_out, r_out, mask, context_metadata = layer.prepare(
|
prepare_output = layer.prepare(hidden_states, router_logits)
|
||||||
hidden_states, router_logits)
|
h_out = prepare_output.hidden_states
|
||||||
|
r_out = prepare_output.router_logits
|
||||||
|
mask = prepare_output.mc2_mask
|
||||||
|
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
|
||||||
|
|
||||||
# Check padding and split
|
# Check padding and split
|
||||||
self.assertEqual(h_out.shape[0], 4)
|
self.assertEqual(h_out.shape[0], 4)
|
||||||
self.assertEqual(r_out.shape[0], 4)
|
self.assertEqual(r_out.shape[0], 4)
|
||||||
self.assertEqual(mask.tolist(), [1, 0, 1])
|
self.assertEqual(mask.tolist(), [1, 0, 1])
|
||||||
|
self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8]))
|
||||||
|
|
||||||
# Finalize
|
# Finalize
|
||||||
result = layer.finalize(h_out,
|
result = layer.finalize(h_out,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
context_metadata=context_metadata)
|
padded_hidden_states_shape=padded_hidden_states_shape)
|
||||||
self.assertEqual(result.shape[0], 3)
|
self.assertEqual(result.shape[0], 3)
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
@@ -79,14 +83,19 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
|||||||
hidden_states = torch.randn(4, 8)
|
hidden_states = torch.randn(4, 8)
|
||||||
router_logits = torch.randn(4, 2)
|
router_logits = torch.randn(4, 2)
|
||||||
|
|
||||||
h_out, r_out, mask, context_metadata = layer.prepare(
|
prepare_output = layer.prepare(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
router_logits,
|
router_logits,
|
||||||
enable_shared_expert_dp=False,
|
enable_shared_expert_dp=False,
|
||||||
replace_allreduce=False)
|
replace_allreduce=False)
|
||||||
|
h_out = prepare_output.hidden_states
|
||||||
|
r_out = prepare_output.router_logits
|
||||||
|
mask = prepare_output.mc2_mask
|
||||||
|
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
|
||||||
|
|
||||||
# With TP=2, should split into 2 parts
|
# With TP=2, should split into 2 parts
|
||||||
self.assertEqual(h_out.shape[0], 2)
|
self.assertEqual(h_out.shape[0], 2)
|
||||||
|
self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8]))
|
||||||
|
|
||||||
# Mock all_gather behavior
|
# Mock all_gather behavior
|
||||||
def mock_all_gather_func(tensor_list, tensor, group=None):
|
def mock_all_gather_func(tensor_list, tensor, group=None):
|
||||||
@@ -101,7 +110,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
final_result = layer.finalize(h_out,
|
final_result = layer.finalize(h_out,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
context_metadata=context_metadata)
|
padded_hidden_states_shape=padded_hidden_states_shape)
|
||||||
|
|
||||||
# Should concat back to original size
|
# Should concat back to original size
|
||||||
self.assertEqual(final_result.shape[0], 4)
|
self.assertEqual(final_result.shape[0], 4)
|
||||||
@@ -117,15 +126,18 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
|||||||
hidden_states = torch.randn(3, 8)
|
hidden_states = torch.randn(3, 8)
|
||||||
router_logits = torch.randn(3, 2)
|
router_logits = torch.randn(3, 2)
|
||||||
|
|
||||||
h_out, r_out, _, context_metadata = layer.prepare(
|
prepare_output = layer.prepare(hidden_states, router_logits)
|
||||||
hidden_states, router_logits)
|
h_out = prepare_output.hidden_states
|
||||||
|
r_out = prepare_output.router_logits
|
||||||
|
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
|
||||||
|
|
||||||
# Pad to tp_size=1, so no change
|
# Pad to tp_size=1, so no change
|
||||||
self.assertEqual(h_out.shape[0], 3)
|
self.assertEqual(h_out.shape[0], 3)
|
||||||
|
self.assertEqual(padded_hidden_states_shape, torch.Size([3, 8]))
|
||||||
|
|
||||||
result = layer.finalize(h_out,
|
result = layer.finalize(h_out,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
context_metadata=context_metadata)
|
padded_hidden_states_shape=padded_hidden_states_shape)
|
||||||
self.assertEqual(result.shape[0], 3)
|
self.assertEqual(result.shape[0], 3)
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
@@ -141,14 +153,18 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
|||||||
hidden_states = torch.randn(2, 8)
|
hidden_states = torch.randn(2, 8)
|
||||||
router_logits = torch.randn(2, 2)
|
router_logits = torch.randn(2, 2)
|
||||||
|
|
||||||
h_out, r_out, _, context_metadata = layer.prepare(
|
prepare_output = layer.prepare(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
router_logits,
|
router_logits,
|
||||||
enable_shared_expert_dp=False,
|
enable_shared_expert_dp=False,
|
||||||
replace_allreduce=False)
|
replace_allreduce=False)
|
||||||
|
h_out = prepare_output.hidden_states
|
||||||
|
r_out = prepare_output.router_logits
|
||||||
|
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
|
||||||
|
|
||||||
# Split due to TP=2
|
# Split due to TP=2
|
||||||
self.assertEqual(h_out.shape[0], 1)
|
self.assertEqual(h_out.shape[0], 1)
|
||||||
|
self.assertEqual(padded_hidden_states_shape, torch.Size([2, 8]))
|
||||||
|
|
||||||
# Mock all_gather
|
# Mock all_gather
|
||||||
def mock_all_gather_func(tensor_list, tensor, group=None):
|
def mock_all_gather_func(tensor_list, tensor, group=None):
|
||||||
@@ -163,7 +179,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
final_result = layer.finalize(h_out,
|
final_result = layer.finalize(h_out,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
context_metadata=context_metadata)
|
padded_hidden_states_shape=padded_hidden_states_shape)
|
||||||
|
|
||||||
# Should concat back
|
# Should concat back
|
||||||
self.assertEqual(final_result.shape[0], 2)
|
self.assertEqual(final_result.shape[0], 2)
|
||||||
@@ -200,12 +216,15 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
|||||||
hidden_states = torch.randn(3, 8)
|
hidden_states = torch.randn(3, 8)
|
||||||
router_logits = torch.randn(3, 2)
|
router_logits = torch.randn(3, 2)
|
||||||
|
|
||||||
h_out, r_out, _, context_metadata = layer.prepare(
|
prepare_output = layer.prepare(hidden_states, router_logits)
|
||||||
hidden_states, router_logits)
|
h_out = prepare_output.hidden_states
|
||||||
|
r_out = prepare_output.router_logits
|
||||||
|
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
|
||||||
|
|
||||||
# After all-gather with DP=2, should double the batch size
|
# After all-gather with DP=2, should double the batch size
|
||||||
self.assertEqual(h_out.shape[0], 12)
|
self.assertEqual(h_out.shape[0], 12)
|
||||||
self.assertEqual(r_out.shape[0], 12)
|
self.assertEqual(r_out.shape[0], 12)
|
||||||
|
self.assertIsNone(padded_hidden_states_shape)
|
||||||
|
|
||||||
# Finalize with reduce_scatter
|
# Finalize with reduce_scatter
|
||||||
def mock_reduce_scatter_func(tensor, dim):
|
def mock_reduce_scatter_func(tensor, dim):
|
||||||
@@ -215,7 +234,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
|||||||
mock_dp_group.reduce_scatter = mock_reduce_scatter_func
|
mock_dp_group.reduce_scatter = mock_reduce_scatter_func
|
||||||
result = layer.finalize(h_out,
|
result = layer.finalize(h_out,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
context_metadata=context_metadata)
|
padded_hidden_states_shape=padded_hidden_states_shape)
|
||||||
|
|
||||||
self.assertEqual(result.shape[0], 3)
|
self.assertEqual(result.shape[0], 3)
|
||||||
|
|
||||||
|
|||||||
@@ -17,14 +17,62 @@
|
|||||||
|
|
||||||
from unittest.mock import MagicMock, PropertyMock, patch
|
from unittest.mock import MagicMock, PropertyMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||||
|
MoEAllGatherCombineMetadata,
|
||||||
|
MoEAllToAllCombineMetadata,
|
||||||
|
MoEMC2CombineMetadata,
|
||||||
|
MoEQuantParams,
|
||||||
|
MoERoutingParams,
|
||||||
|
MoETokenDispatchInput,
|
||||||
|
)
|
||||||
from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip
|
from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip
|
||||||
AscendDeviceType, TokenDispatcherWithAll2AllV,
|
AscendDeviceType,
|
||||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
|
TokenDispatcherWithAll2AllV,
|
||||||
|
TokenDispatcherWithAllGather,
|
||||||
|
TokenDispatcherWithMC2,
|
||||||
|
)
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
|
||||||
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
|
|
||||||
|
|
||||||
|
def build_token_dispatch_input_fixture(
|
||||||
|
*,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
expert_map: torch.Tensor | None = None,
|
||||||
|
global_redundant_expert_num: int = 0,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
pertoken_scale: torch.Tensor | None = None,
|
||||||
|
quant_type: QuantType = QuantType.NONE,
|
||||||
|
comm_quant_mode: int | None = None,
|
||||||
|
act_quant_type: torch.dtype | None = None,
|
||||||
|
) -> MoETokenDispatchInput:
|
||||||
|
mxfp_spec = None
|
||||||
|
if quant_type == QuantType.MXFP8:
|
||||||
|
mxfp_spec = MoEMxfpParams(act_quant_type=act_quant_type)
|
||||||
|
return MoETokenDispatchInput(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
routing=MoERoutingParams(
|
||||||
|
expert_map=expert_map,
|
||||||
|
global_redundant_expert_num=global_redundant_expert_num,
|
||||||
|
mc2_mask=None,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
pertoken_scale=pertoken_scale,
|
||||||
|
),
|
||||||
|
quant=MoEQuantParams(
|
||||||
|
quant_type=quant_type,
|
||||||
|
comm_quant_mode=comm_quant_mode,
|
||||||
|
mxfp=mxfp_spec,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestTokenDispatcherWithMC2(TestBase):
|
class TestTokenDispatcherWithMC2(TestBase):
|
||||||
@@ -85,7 +133,6 @@ class TestTokenDispatcherWithMC2(TestBase):
|
|||||||
def test_init(self):
|
def test_init(self):
|
||||||
self.assertEqual(self.dispatcher.ep_rank_id, 0)
|
self.assertEqual(self.dispatcher.ep_rank_id, 0)
|
||||||
self.assertEqual(self.dispatcher.ep_world_size, 8)
|
self.assertEqual(self.dispatcher.ep_world_size, 8)
|
||||||
self.assertFalse(self.dispatcher.with_quant)
|
|
||||||
self.assertTrue(self.dispatcher.enable_dispatch_v2)
|
self.assertTrue(self.dispatcher.enable_dispatch_v2)
|
||||||
self.assertTrue(self.dispatcher.need_extra_args)
|
self.assertTrue(self.dispatcher.need_extra_args)
|
||||||
|
|
||||||
@@ -94,10 +141,16 @@ class TestTokenDispatcherWithMC2(TestBase):
|
|||||||
topk_ids = torch.randint(0, 8, (10, 1))
|
topk_ids = torch.randint(0, 8, (10, 1))
|
||||||
topk_weights = torch.randn(10, 1)
|
topk_weights = torch.randn(10, 1)
|
||||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||||
mc2_mask = None
|
token_dispatch_input = build_token_dispatch_input_fixture(
|
||||||
|
hidden_states=hidden_states,
|
||||||
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
|
topk_weights=topk_weights,
|
||||||
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask)
|
topk_ids=topk_ids,
|
||||||
|
expert_map=expert_map,
|
||||||
|
global_redundant_expert_num=0,
|
||||||
|
apply_router_weight_on_input=False,
|
||||||
|
pertoken_scale=None,
|
||||||
|
)
|
||||||
|
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(token_dispatch_input)
|
||||||
self.assertIn("x", kwargs)
|
self.assertIn("x", kwargs)
|
||||||
self.assertIn("expert_ids", kwargs)
|
self.assertIn("expert_ids", kwargs)
|
||||||
self.assertEqual(kwargs["moe_expert_num"], 8)
|
self.assertEqual(kwargs["moe_expert_num"], 8)
|
||||||
@@ -111,39 +164,42 @@ class TestTokenDispatcherWithMC2(TestBase):
|
|||||||
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
||||||
return_value=(torch.randn(10, 128), ) * 5 +
|
return_value=(torch.randn(10, 128), ) * 5 +
|
||||||
(None, None)) as mock_dispatch:
|
(None, None)) as mock_dispatch:
|
||||||
output = self.dispatcher.token_dispatch(hidden_states,
|
token_dispatch_input = build_token_dispatch_input_fixture(
|
||||||
topk_weights, topk_ids,
|
hidden_states=hidden_states,
|
||||||
expert_map)
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
expert_map=expert_map,
|
||||||
|
)
|
||||||
|
output = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
|
||||||
mock_dispatch.assert_called_once()
|
mock_dispatch.assert_called_once()
|
||||||
self.assertEqual(output.group_list_type, 0) # group_list_type == 0
|
self.assertEqual(output.group_list_type, 0) # group_list_type == 0
|
||||||
|
self.assertIsInstance(output.combine_metadata, MoEMC2CombineMetadata)
|
||||||
|
|
||||||
def test_get_combine_mc_kwargs_with_quant(self):
|
def test_get_combine_mc_kwargs_with_quant(self):
|
||||||
self.dispatcher.with_quant = True
|
|
||||||
hidden_states = torch.randn(10, 128)
|
hidden_states = torch.randn(10, 128)
|
||||||
topk_ids = torch.randint(0, 8, (10, 1))
|
topk_ids = torch.randint(0, 8, (10, 1))
|
||||||
topk_weights = torch.randn(10, 1)
|
topk_weights = torch.randn(10, 1)
|
||||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||||
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||||
tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||||
mc2_mask = None
|
|
||||||
assist_info_for_combine = torch.arange(10)
|
assist_info_for_combine = torch.arange(10)
|
||||||
|
|
||||||
context_metadata = {
|
combine_metadata = MoEMC2CombineMetadata(
|
||||||
"topk_ids": topk_ids,
|
topk_ids=topk_ids,
|
||||||
"topk_weights": topk_weights,
|
topk_weights=topk_weights,
|
||||||
"expert_map": expert_map,
|
expert_map=expert_map,
|
||||||
"ep_recv_counts": ep_recv_counts,
|
ep_recv_counts=ep_recv_counts,
|
||||||
"mc2_mask": mc2_mask,
|
tp_recv_counts=tp_recv_counts,
|
||||||
"assist_info_for_combine": assist_info_for_combine,
|
assist_info_for_combine=assist_info_for_combine,
|
||||||
"expand_scales": None,
|
expand_scales=None,
|
||||||
"tp_recv_counts": tp_recv_counts
|
dispatch_with_quant=True,
|
||||||
}
|
)
|
||||||
|
|
||||||
self.dispatcher.need_extra_args = True
|
self.dispatcher.need_extra_args = True
|
||||||
self.dispatcher.enable_dispatch_v2 = True
|
self.dispatcher.enable_dispatch_v2 = True
|
||||||
self.dispatcher.moe_expert_num = len(expert_map)
|
self.dispatcher.moe_expert_num = len(expert_map)
|
||||||
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
|
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
|
||||||
context_metadata)
|
combine_metadata)
|
||||||
self.assertIn("tp_send_counts", kwargs)
|
self.assertIn("tp_send_counts", kwargs)
|
||||||
|
|
||||||
|
|
||||||
@@ -188,14 +244,19 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
|||||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||||
|
|
||||||
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
|
token_dispatch_input = build_token_dispatch_input_fixture(
|
||||||
topk_ids, None)
|
hidden_states=hidden_states,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
)
|
||||||
|
results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
|
||||||
|
|
||||||
# Verify npu_moe_init_routing is called
|
# Verify npu_moe_init_routing is called
|
||||||
self.mock_npu_moe_init_routing_custom.assert_called_once()
|
self.mock_npu_moe_init_routing_custom.assert_called_once()
|
||||||
args, kwargs = self.mock_npu_moe_init_routing_custom.call_args
|
args, kwargs = self.mock_npu_moe_init_routing_custom.call_args
|
||||||
|
|
||||||
self.assertEqual(results.group_list_type, 1)
|
self.assertEqual(results.group_list_type, 1)
|
||||||
|
self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
||||||
@@ -205,14 +266,19 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
|||||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||||
|
|
||||||
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
|
token_dispatch_input = build_token_dispatch_input_fixture(
|
||||||
topk_ids, None)
|
hidden_states=hidden_states,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
)
|
||||||
|
results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
|
||||||
|
|
||||||
# Verify npu_moe_init_routing is called
|
# Verify npu_moe_init_routing is called
|
||||||
self.mock_npu_moe_init_routing_custom.assert_called_once()
|
self.mock_npu_moe_init_routing_custom.assert_called_once()
|
||||||
args, kwargs = self.mock_npu_moe_init_routing_custom.call_args
|
args, kwargs = self.mock_npu_moe_init_routing_custom.call_args
|
||||||
|
|
||||||
self.assertEqual(results.group_list_type, 1)
|
self.assertEqual(results.group_list_type, 1)
|
||||||
|
self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
||||||
@@ -230,9 +296,12 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
|||||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||||
|
|
||||||
results = self.dispatcher_quant.token_dispatch(hidden_states,
|
token_dispatch_input = build_token_dispatch_input_fixture(
|
||||||
topk_weights, topk_ids,
|
hidden_states=hidden_states,
|
||||||
None)
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
)
|
||||||
|
results = self.dispatcher_quant.token_dispatch(token_dispatch_input=token_dispatch_input)
|
||||||
|
|
||||||
self.assertEqual(results.group_list_type, 1)
|
self.assertEqual(results.group_list_type, 1)
|
||||||
|
|
||||||
@@ -252,11 +321,13 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
|||||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||||
|
|
||||||
results = self.dispatcher_quant.token_dispatch(hidden_states,
|
token_dispatch_input = build_token_dispatch_input_fixture(
|
||||||
topk_weights,
|
hidden_states=hidden_states,
|
||||||
topk_ids,
|
topk_weights=topk_weights,
|
||||||
None,
|
topk_ids=topk_ids,
|
||||||
with_quant=True)
|
quant_type=QuantType.W8A8,
|
||||||
|
)
|
||||||
|
results = self.dispatcher_quant.token_dispatch(token_dispatch_input=token_dispatch_input)
|
||||||
|
|
||||||
self.assertIsNotNone(results.hidden_states)
|
self.assertIsNotNone(results.hidden_states)
|
||||||
self.assertIsNotNone(results.group_list)
|
self.assertIsNotNone(results.group_list)
|
||||||
@@ -267,40 +338,43 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
|||||||
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
||||||
def test_token_combine_with_expert_map(self):
|
def test_token_combine_with_expert_map(self):
|
||||||
hidden_states = torch.randn(6, 128)
|
hidden_states = torch.randn(6, 128)
|
||||||
context_metadata = {
|
combine_metadata = MoEAllGatherCombineMetadata(
|
||||||
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
|
expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
|
||||||
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
|
topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
|
||||||
}
|
restore_shape=torch.Size([6, 128]),
|
||||||
self.dispatcher.original_shape = (6, 128)
|
)
|
||||||
final_hidden_states = self.dispatcher.token_combine(
|
final_hidden_states = self.dispatcher.token_combine(hidden_states, combine_metadata)
|
||||||
hidden_states, context_metadata).routed_out
|
|
||||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
||||||
def test_token_combine_without_expert_map(self):
|
def test_token_combine_without_expert_map(self):
|
||||||
hidden_states = torch.randn(6, 128)
|
hidden_states = torch.randn(6, 128)
|
||||||
context_metadata = {
|
combine_metadata = MoEAllGatherCombineMetadata(
|
||||||
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
|
expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
|
||||||
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
|
topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
|
||||||
}
|
restore_shape=torch.Size([6, 128]),
|
||||||
self.dispatcher.original_shape = (6, 128)
|
)
|
||||||
final_hidden_states = self.dispatcher.token_combine(
|
final_hidden_states = self.dispatcher.token_combine(hidden_states, combine_metadata)
|
||||||
hidden_states, context_metadata).routed_out
|
|
||||||
self.mock_npu_moe_token_unpermute.assert_called_once()
|
self.mock_npu_moe_token_unpermute.assert_called_once()
|
||||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
||||||
def test_token_dispatch_with_router_weight(self):
|
def test_token_dispatch_with_router_weight(self):
|
||||||
self.dispatcher.apply_router_weight_on_input = True
|
|
||||||
hidden_states = torch.randn(3, 128)
|
hidden_states = torch.randn(3, 128)
|
||||||
topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1
|
topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1
|
||||||
topk_ids = torch.tensor([[0], [1], [2]])
|
topk_ids = torch.tensor([[0], [1], [2]])
|
||||||
|
|
||||||
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
|
token_dispatch_input = build_token_dispatch_input_fixture(
|
||||||
topk_ids, None)
|
hidden_states=hidden_states,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
apply_router_weight_on_input=True,
|
||||||
|
)
|
||||||
|
results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
|
||||||
self.assertEqual(results.hidden_states.shape, (6, 128))
|
self.assertEqual(results.hidden_states.shape, (6, 128))
|
||||||
|
self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)
|
||||||
|
|
||||||
|
|
||||||
class TestTokenDispatcherWithAll2AllV(TestBase):
|
class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||||
@@ -408,35 +482,39 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
|||||||
[0, 1], dtype=torch.int32)
|
[0, 1], dtype=torch.int32)
|
||||||
self.dispatcher.local_expert_indices = [0, 1]
|
self.dispatcher.local_expert_indices = [0, 1]
|
||||||
|
|
||||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
token_dispatch_input = build_token_dispatch_input_fixture(
|
||||||
topk_weights=topk_weights,
|
hidden_states=hidden_states,
|
||||||
topk_ids=topk_ids,
|
topk_weights=topk_weights,
|
||||||
expert_map=expert_map)
|
topk_ids=topk_ids,
|
||||||
|
expert_map=expert_map,
|
||||||
|
)
|
||||||
|
result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
|
||||||
|
|
||||||
self.assertIsNotNone(result.hidden_states)
|
self.assertIsNotNone(result.hidden_states)
|
||||||
self.assertIsNotNone(result.group_list)
|
self.assertIsNotNone(result.group_list)
|
||||||
self.assertEqual(result.group_list_type, 1)
|
self.assertEqual(result.group_list_type, 1)
|
||||||
|
self.assertIsInstance(result.combine_metadata, MoEAllToAllCombineMetadata)
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
||||||
def test_token_combine(self):
|
def test_token_combine(self):
|
||||||
hidden_states = torch.randn(16, 16)
|
hidden_states = torch.randn(16, 16)
|
||||||
context_metadata = {
|
combine_metadata = MoEAllToAllCombineMetadata(
|
||||||
"input_splits": [4, 4],
|
input_splits=np.array([4, 4]),
|
||||||
"output_splits": [4, 4],
|
output_splits=np.array([4, 4]),
|
||||||
"topk_weights": torch.rand(8, 4),
|
topk_weights=torch.rand(8, 4),
|
||||||
"reversed_local_input_permutation_mapping": torch.arange(8),
|
reversed_local_input_permutation_mapping=torch.arange(8),
|
||||||
"reversed_global_input_permutation_mapping": torch.arange(16),
|
reversed_global_input_permutation_mapping=torch.arange(16),
|
||||||
}
|
hidden_shape=torch.Size([8, 16]),
|
||||||
self.dispatcher.hidden_shape = (8, 16)
|
hidden_shape_before_permute=torch.Size([8, 16]),
|
||||||
self.dispatcher.hidden_shape_before_permute = (8, 16)
|
)
|
||||||
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
|
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
|
||||||
[0, 1], dtype=torch.int32)
|
[0, 1], dtype=torch.int32)
|
||||||
self.dispatcher.local_expert_indices = [0, 1]
|
self.dispatcher.local_expert_indices = [0, 1]
|
||||||
|
|
||||||
output = self.dispatcher.token_combine(hidden_states, context_metadata)
|
output = self.dispatcher.token_combine(hidden_states, combine_metadata)
|
||||||
self.assertIsNotNone(output)
|
self.assertIsNotNone(output)
|
||||||
self.assertEqual(output.routed_out.shape, (8, 16))
|
self.assertEqual(output.shape, (8, 16))
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
||||||
@@ -454,16 +532,20 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
|||||||
[0, 1], dtype=torch.int32)
|
[0, 1], dtype=torch.int32)
|
||||||
self.dispatcher.local_expert_indices = [0, 1]
|
self.dispatcher.local_expert_indices = [0, 1]
|
||||||
|
|
||||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
token_dispatch_input = build_token_dispatch_input_fixture(
|
||||||
topk_weights=topk_weights,
|
hidden_states=hidden_states,
|
||||||
topk_ids=topk_ids,
|
topk_weights=topk_weights,
|
||||||
expert_map=expert_map,
|
topk_ids=topk_ids,
|
||||||
with_quant=True)
|
expert_map=expert_map,
|
||||||
|
quant_type=QuantType.W8A8,
|
||||||
|
)
|
||||||
|
result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
|
||||||
|
|
||||||
self.assertIsNotNone(result.hidden_states)
|
self.assertIsNotNone(result.hidden_states)
|
||||||
self.assertIsNotNone(result.group_list)
|
self.assertIsNotNone(result.group_list)
|
||||||
self.assertIsNotNone(result.dynamic_scale)
|
self.assertIsNotNone(result.dynamic_scale)
|
||||||
self.assertEqual(result.group_list_type, 1)
|
self.assertEqual(result.group_list_type, 1)
|
||||||
|
self.assertIsInstance(result.combine_metadata, MoEAllToAllCombineMetadata)
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
|
||||||
@@ -484,14 +566,16 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
|||||||
[0, 1], dtype=torch.int32)
|
[0, 1], dtype=torch.int32)
|
||||||
self.dispatcher.local_expert_indices = [0, 1]
|
self.dispatcher.local_expert_indices = [0, 1]
|
||||||
|
|
||||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
token_dispatch_input = build_token_dispatch_input_fixture(
|
||||||
topk_weights=topk_weights,
|
hidden_states=hidden_states,
|
||||||
topk_ids=topk_ids,
|
topk_weights=topk_weights,
|
||||||
expert_map=expert_map,
|
topk_ids=topk_ids,
|
||||||
with_quant=True)
|
expert_map=expert_map,
|
||||||
|
quant_type=QuantType.W8A8,
|
||||||
|
)
|
||||||
|
result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
|
||||||
|
|
||||||
self.assertIsNotNone(result.hidden_states)
|
self.assertIsNotNone(result.hidden_states)
|
||||||
self.assertIsNotNone(result.group_list)
|
self.assertIsNotNone(result.group_list)
|
||||||
self.assertIsNotNone(result.dynamic_scale)
|
self.assertIsNotNone(result.dynamic_scale)
|
||||||
self.assertEqual(result.group_list_type, 1)
|
self.assertEqual(result.group_list_type, 1)
|
||||||
|
|
||||||
|
|||||||
@@ -3,9 +3,8 @@ from unittest.mock import Mock, patch
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.quantization.methods.w4a16 import (AscendW4A16FusedMoEMethod,
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
pack_to_int32,
|
from vllm_ascend.quantization.methods.w4a16 import AscendW4A16FusedMoEMethod, pack_to_int32, unpack_from_int32
|
||||||
unpack_from_int32)
|
|
||||||
|
|
||||||
|
|
||||||
class TestUnpackFromInt32(TestBase):
|
class TestUnpackFromInt32(TestBase):
|
||||||
@@ -268,3 +267,41 @@ class TestAscendW4A16FusedMoEMethod(TestBase):
|
|||||||
torch.equal(layer.w13_weight_packed.data, original_w13_data))
|
torch.equal(layer.w13_weight_packed.data, original_w13_data))
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
torch.equal(layer.w2_weight_packed.data, original_w2_data))
|
torch.equal(layer.w2_weight_packed.data, original_w2_data))
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.methods.w4a16._EXTRA_CTX")
|
||||||
|
@patch("vllm_ascend.quantization.methods.w4a16.select_experts")
|
||||||
|
def test_apply_uses_explicit_dispatch_and_mlp_args(self, mock_select_experts, mock_extra_ctx):
|
||||||
|
tokens = 3
|
||||||
|
hidden_size = self.output_size
|
||||||
|
layer = self.build_layer()
|
||||||
|
x = torch.randn(tokens, hidden_size, dtype=torch.float32)
|
||||||
|
router_logits = torch.randn(tokens, self.experts, dtype=torch.float32)
|
||||||
|
topk_weights = torch.randn(tokens, 2, dtype=torch.float32)
|
||||||
|
topk_ids = torch.randint(0, self.experts, (tokens, 2), dtype=torch.int64)
|
||||||
|
mc2_mask = torch.tensor([1, 0, 1], dtype=torch.bool)
|
||||||
|
pertoken_scale = torch.randn(tokens, dtype=torch.float32)
|
||||||
|
|
||||||
|
mock_select_experts.return_value = (topk_weights, topk_ids)
|
||||||
|
mock_comm = Mock()
|
||||||
|
mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32)
|
||||||
|
mock_extra_ctx.moe_comm_method = mock_comm
|
||||||
|
mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER
|
||||||
|
|
||||||
|
self.quant_method.apply(
|
||||||
|
layer=layer,
|
||||||
|
x=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=2,
|
||||||
|
renormalize=True,
|
||||||
|
global_num_experts=self.experts,
|
||||||
|
activation="gelu",
|
||||||
|
apply_router_weight_on_input=True,
|
||||||
|
mc2_mask=mc2_mask,
|
||||||
|
pertoken_scale=pertoken_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"]
|
||||||
|
self.assertEqual(fused_experts_input.activation, "gelu")
|
||||||
|
self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
|
||||||
|
self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask)
|
||||||
|
self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale)
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ from unittest.mock import Mock, patch
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.quantization.methods.w8a8_dynamic import \
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
AscendW8A8DynamicFusedMoEMethod
|
from vllm_ascend.quantization.methods.w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod
|
||||||
|
|
||||||
|
|
||||||
class TestAscendW8A8FusedMoEMethod(TestBase):
|
class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||||
@@ -32,8 +32,9 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
|
|||||||
mock_ep_group = Mock()
|
mock_ep_group = Mock()
|
||||||
mock_get_ep_group.return_value = mock_ep_group
|
mock_get_ep_group.return_value = mock_ep_group
|
||||||
mock_ascend_config = Mock()
|
mock_ascend_config = Mock()
|
||||||
|
|
||||||
mock_ascend_config.enable_chunked_prefill = False
|
mock_ascend_config.enable_chunked_prefill = False
|
||||||
|
mock_ascend_config.multistream_overlap_gate = False
|
||||||
|
mock_ascend_config.eplb_config = Mock(dynamic_eplb=False)
|
||||||
mock_get_ascend_config.return_value = mock_ascend_config
|
mock_get_ascend_config.return_value = mock_ascend_config
|
||||||
mock_mc2_group = Mock(device_group=0)
|
mock_mc2_group = Mock(device_group=0)
|
||||||
mock_get_mc2_group.return_value = mock_mc2_group
|
mock_get_mc2_group.return_value = mock_mc2_group
|
||||||
@@ -104,3 +105,125 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
|
|||||||
new_layer = self.build_layer()
|
new_layer = self.build_layer()
|
||||||
self.quant_method.process_weights_after_loading(new_layer)
|
self.quant_method.process_weights_after_loading(new_layer)
|
||||||
mock_npu_format_cast.assert_called()
|
mock_npu_format_cast.assert_called()
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.methods.w8a8_dynamic._EXTRA_CTX")
|
||||||
|
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.select_experts")
|
||||||
|
def test_apply_uses_explicit_dispatch_and_mlp_args(self, mock_select_experts, mock_extra_ctx):
|
||||||
|
tokens = 4
|
||||||
|
hidden_size = self.hidden_size
|
||||||
|
layer = torch.nn.Module()
|
||||||
|
layer.w13_weight = torch.randint(
|
||||||
|
-8,
|
||||||
|
8,
|
||||||
|
(self.num_experts, 2 * self.intermediate_size, hidden_size),
|
||||||
|
dtype=torch.int8,
|
||||||
|
)
|
||||||
|
layer.w2_weight = torch.randint(
|
||||||
|
-8,
|
||||||
|
8,
|
||||||
|
(self.num_experts, hidden_size, self.intermediate_size),
|
||||||
|
dtype=torch.int8,
|
||||||
|
)
|
||||||
|
layer.w13_weight_scale_fp32 = torch.ones(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32)
|
||||||
|
layer.w2_weight_scale = torch.ones(self.num_experts, hidden_size, dtype=torch.float32)
|
||||||
|
|
||||||
|
x = torch.randn(tokens, hidden_size, dtype=torch.float32)
|
||||||
|
router_logits = torch.randn(tokens, self.num_experts, dtype=torch.float32)
|
||||||
|
topk_weights = torch.randn(tokens, 2, dtype=torch.float32)
|
||||||
|
topk_ids = torch.randint(0, self.num_experts, (tokens, 2), dtype=torch.int64)
|
||||||
|
mc2_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool)
|
||||||
|
pertoken_scale = torch.randn(tokens, dtype=torch.float32)
|
||||||
|
|
||||||
|
mock_select_experts.return_value = (topk_weights, topk_ids)
|
||||||
|
mock_comm = Mock()
|
||||||
|
mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32)
|
||||||
|
mock_extra_ctx.moe_comm_method = mock_comm
|
||||||
|
mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER
|
||||||
|
self.quant_method.multistream_overlap_gate = False
|
||||||
|
self.quant_method.in_dtype = torch.float32
|
||||||
|
|
||||||
|
self.quant_method.apply(
|
||||||
|
layer=layer,
|
||||||
|
x=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=2,
|
||||||
|
renormalize=True,
|
||||||
|
global_num_experts=self.num_experts,
|
||||||
|
activation="gelu",
|
||||||
|
apply_router_weight_on_input=True,
|
||||||
|
mc2_mask=mc2_mask,
|
||||||
|
pertoken_scale=pertoken_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"]
|
||||||
|
self.assertEqual(fused_experts_input.activation, "gelu")
|
||||||
|
self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
|
||||||
|
self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask)
|
||||||
|
self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale)
|
||||||
|
self.assertIs(fused_experts_input.topk_weights, topk_weights)
|
||||||
|
self.assertIs(fused_experts_input.topk_ids, topk_ids)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_flash_common3_context")
|
||||||
|
@patch("vllm_ascend.quantization.methods.w8a8_dynamic._EXTRA_CTX")
|
||||||
|
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.select_experts")
|
||||||
|
def test_apply_overlap_gate_uses_fc3_context(
|
||||||
|
self,
|
||||||
|
mock_select_experts,
|
||||||
|
mock_extra_ctx,
|
||||||
|
mock_get_flash_common3_context,
|
||||||
|
):
|
||||||
|
tokens = 4
|
||||||
|
hidden_size = self.hidden_size
|
||||||
|
layer = torch.nn.Module()
|
||||||
|
layer.w13_weight = torch.randint(
|
||||||
|
-8,
|
||||||
|
8,
|
||||||
|
(self.num_experts, 2 * self.intermediate_size, hidden_size),
|
||||||
|
dtype=torch.int8,
|
||||||
|
)
|
||||||
|
layer.w2_weight = torch.randint(
|
||||||
|
-8,
|
||||||
|
8,
|
||||||
|
(self.num_experts, hidden_size, self.intermediate_size),
|
||||||
|
dtype=torch.int8,
|
||||||
|
)
|
||||||
|
layer.w13_weight_scale_fp32 = torch.ones(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32)
|
||||||
|
layer.w2_weight_scale = torch.ones(self.num_experts, hidden_size, dtype=torch.float32)
|
||||||
|
|
||||||
|
x = torch.randn(tokens, hidden_size, dtype=torch.float32)
|
||||||
|
router_logits = torch.randn(tokens, self.num_experts, dtype=torch.float32)
|
||||||
|
topk_weights = torch.randn(tokens, 2, dtype=torch.float32)
|
||||||
|
topk_ids = torch.randint(0, self.num_experts, (tokens, 2), dtype=torch.int64)
|
||||||
|
mc2_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool)
|
||||||
|
pertoken_scale = torch.randn(tokens, dtype=torch.float32)
|
||||||
|
|
||||||
|
self.quant_method.multistream_overlap_gate = True
|
||||||
|
self.quant_method.in_dtype = torch.float32
|
||||||
|
mock_get_flash_common3_context.return_value = Mock(topk_weights=topk_weights, topk_ids=topk_ids)
|
||||||
|
|
||||||
|
mock_comm = Mock()
|
||||||
|
mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32)
|
||||||
|
mock_extra_ctx.moe_comm_method = mock_comm
|
||||||
|
mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER
|
||||||
|
|
||||||
|
self.quant_method.apply(
|
||||||
|
layer=layer,
|
||||||
|
x=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=2,
|
||||||
|
renormalize=True,
|
||||||
|
global_num_experts=self.num_experts,
|
||||||
|
activation="gelu",
|
||||||
|
apply_router_weight_on_input=True,
|
||||||
|
mc2_mask=mc2_mask,
|
||||||
|
pertoken_scale=pertoken_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_select_experts.assert_not_called()
|
||||||
|
fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"]
|
||||||
|
self.assertEqual(fused_experts_input.activation, "gelu")
|
||||||
|
self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
|
||||||
|
self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask)
|
||||||
|
self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale)
|
||||||
|
self.assertIs(fused_experts_input.topk_weights, topk_weights)
|
||||||
|
self.assertIs(fused_experts_input.topk_ids, topk_ids)
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
|
|||||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
|
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
|
||||||
from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods
|
from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods
|
||||||
from vllm_ascend.quantization.methods.base import QuantType
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||||
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
|
|
||||||
from .experts_selector import select_experts
|
from .experts_selector import select_experts
|
||||||
from .moe_comm_method import AllGatherCommImpl310
|
from .moe_comm_method import AllGatherCommImpl310
|
||||||
@@ -93,13 +94,17 @@ class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod):
|
|||||||
|
|
||||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||||
final_hidden_states = moe_comm_method.fused_experts(
|
final_hidden_states = moe_comm_method.fused_experts(
|
||||||
hidden_states=x,
|
fused_experts_input=build_fused_experts_input(
|
||||||
w1=layer.w13_weight,
|
hidden_states=x,
|
||||||
w2=layer.w2_weight,
|
topk_weights=topk_weights,
|
||||||
topk_weights=topk_weights,
|
topk_ids=topk_ids,
|
||||||
topk_ids=topk_ids,
|
w1=layer.w13_weight,
|
||||||
expert_map=expert_map,
|
w2=layer.w2_weight,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
quant_type=QuantType.NONE,
|
||||||
|
dynamic_eplb=False,
|
||||||
|
expert_map=expert_map,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||||
final_hidden_states += zero_expert_result
|
final_hidden_states += zero_expert_result
|
||||||
@@ -218,9 +223,13 @@ class AscendFusedMoE310(FusedMoE):
|
|||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported."
|
assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported."
|
||||||
|
|
||||||
hidden_states, router_logits, _, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
|
prepare_output = _EXTRA_CTX.moe_comm_method.prepare(
|
||||||
hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type
|
hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type
|
||||||
)
|
)
|
||||||
|
hidden_states = prepare_output.hidden_states
|
||||||
|
router_logits = prepare_output.router_logits
|
||||||
|
pertoken_scale = prepare_output.pertoken_scale
|
||||||
|
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
fused_experts_results: FusedExpertsResult = self.quant_method.apply(
|
fused_experts_results: FusedExpertsResult = self.quant_method.apply(
|
||||||
@@ -238,12 +247,13 @@ class AscendFusedMoE310(FusedMoE):
|
|||||||
global_num_experts=self.global_num_experts,
|
global_num_experts=self.global_num_experts,
|
||||||
expert_map=self.local_expert_map,
|
expert_map=self.local_expert_map,
|
||||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||||
|
pertoken_scale=pertoken_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
routed_out = _EXTRA_CTX.moe_comm_method.finalize(
|
routed_out = _EXTRA_CTX.moe_comm_method.finalize(
|
||||||
hidden_states=fused_experts_results.routed_out,
|
hidden_states=fused_experts_results.routed_out,
|
||||||
reduce_results=self.reduce_results,
|
reduce_results=self.reduce_results,
|
||||||
context_metadata=context_metadata,
|
padded_hidden_states_shape=padded_hidden_states_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
return routed_out
|
return routed_out
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl
|
||||||
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
|
||||||
|
|
||||||
from .moe_mlp import unified_apply_mlp
|
from .moe_mlp import unified_apply_mlp
|
||||||
from .token_dispatcher import TokenDispatcherWithAllGather310
|
from .token_dispatcher import TokenDispatcherWithAllGather310
|
||||||
@@ -35,52 +35,12 @@ class AllGatherCommImpl310(AllGatherCommImpl):
|
|||||||
to handle the token-to-expert mapping and communication efficiently.
|
to handle the token-to-expert mapping and communication efficiently.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fused_experts( # type: ignore[override]
|
def __init__(self, moe_config):
|
||||||
self,
|
super().__init__(moe_config)
|
||||||
hidden_states: torch.Tensor,
|
self.use_fusion_ops = False
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
expert_map: torch.Tensor | None = None,
|
|
||||||
use_int8_w8a8: bool = False,
|
|
||||||
w1_scale: torch.Tensor | None = None,
|
|
||||||
w2_scale: torch.Tensor | None = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
) -> FusedExpertsResult:
|
|
||||||
# This method is overridden to use the 310p-specific unified_apply_mlp
|
|
||||||
# which provides optimized MLP computation for the 310p platform
|
|
||||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
|
||||||
assert moe_comm_method is not None, "Missing communication context"
|
|
||||||
|
|
||||||
dispatch_results = self.token_dispatcher.token_dispatch(
|
def _apply_mlp(self, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
|
||||||
hidden_states=hidden_states,
|
return unified_apply_mlp(mlp_compute_input=mlp_compute_input)
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
expert_map=expert_map,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
)
|
|
||||||
|
|
||||||
mlp_output = unified_apply_mlp(
|
|
||||||
hidden_states=dispatch_results.hidden_states,
|
|
||||||
w1=w1,
|
|
||||||
w2=w2,
|
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
group_list=dispatch_results.group_list,
|
|
||||||
group_list_type=dispatch_results.group_list_type,
|
|
||||||
with_quant=use_int8_w8a8,
|
|
||||||
)
|
|
||||||
|
|
||||||
combine_results = self.token_dispatcher.token_combine(
|
|
||||||
hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
return FusedExpertsResult(
|
|
||||||
routed_out=combine_results.routed_out,
|
|
||||||
group_list_type=dispatch_results.group_list_type,
|
|
||||||
expert_tokens=dispatch_results.group_list,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_token_dispatcher(self):
|
def _get_token_dispatcher(self):
|
||||||
return TokenDispatcherWithAllGather310(
|
return TokenDispatcherWithAllGather310(
|
||||||
|
|||||||
@@ -18,6 +18,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
|
||||||
|
|
||||||
|
|
||||||
def quant_apply_mlp(
|
def quant_apply_mlp(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -66,17 +68,20 @@ def unquant_apply_mlp(
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def unified_apply_mlp(
|
def unified_apply_mlp(*, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
|
||||||
hidden_states: torch.Tensor,
|
hidden_states = mlp_compute_input.hidden_states
|
||||||
w1: torch.Tensor,
|
w1 = mlp_compute_input.weights.w1
|
||||||
w2: torch.Tensor,
|
w2 = mlp_compute_input.weights.w2
|
||||||
group_list: torch.Tensor,
|
w1_scale = mlp_compute_input.weights.w1_scale
|
||||||
w1_scale: torch.Tensor | None = None,
|
w2_scale = mlp_compute_input.weights.w2_scale
|
||||||
w2_scale: torch.Tensor | None = None,
|
group_list = mlp_compute_input.group_list
|
||||||
group_list_type: int = 1,
|
group_list_type = mlp_compute_input.group_list_type
|
||||||
with_quant: bool = False,
|
assert isinstance(w1, torch.Tensor)
|
||||||
) -> torch.Tensor:
|
assert isinstance(w2, torch.Tensor)
|
||||||
if with_quant:
|
|
||||||
|
if mlp_compute_input.quant.is_quant:
|
||||||
|
assert isinstance(w1_scale, torch.Tensor)
|
||||||
|
assert isinstance(w2_scale, torch.Tensor)
|
||||||
assert w1_scale is not None and w2_scale is not None
|
assert w1_scale is not None and w2_scale is not None
|
||||||
return quant_apply_mlp(
|
return quant_apply_mlp(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
@@ -87,7 +92,11 @@ def unified_apply_mlp(
|
|||||||
group_list=group_list,
|
group_list=group_list,
|
||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return unquant_apply_mlp(
|
return unquant_apply_mlp(
|
||||||
hidden_states=hidden_states, w1=w1, w2=w2, group_list=group_list, group_list_type=group_list_type
|
hidden_states=hidden_states,
|
||||||
)
|
w1=w1,
|
||||||
|
w2=w2,
|
||||||
|
group_list=group_list,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
)
|
||||||
|
|||||||
@@ -25,26 +25,27 @@
|
|||||||
import torch
|
import torch
|
||||||
from vllm.distributed.parallel_state import get_ep_group
|
from vllm.distributed.parallel_state import get_ep_group
|
||||||
|
|
||||||
from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather, TokenDispatchResult
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEAllGatherCombineMetadata, MoETokenDispatchInput
|
||||||
|
from vllm_ascend.ops.fused_moe.token_dispatcher import MoETokenDispatchOutput, TokenDispatcherWithAllGather
|
||||||
|
|
||||||
|
|
||||||
class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
|
class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def token_dispatch( # type: ignore[override]
|
def token_dispatch(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
token_dispatch_input: MoETokenDispatchInput,
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
expert_map: torch.Tensor | None = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
):
|
):
|
||||||
self.original_shape = hidden_states.shape
|
hidden_states = token_dispatch_input.hidden_states
|
||||||
|
topk_weights = token_dispatch_input.topk_weights
|
||||||
|
topk_ids = token_dispatch_input.topk_ids
|
||||||
|
expert_map = token_dispatch_input.routing.expert_map
|
||||||
|
apply_router_weight_on_input = token_dispatch_input.routing.apply_router_weight_on_input
|
||||||
|
restore_shape = hidden_states.shape
|
||||||
|
|
||||||
num_tokens = hidden_states.shape[:-1].numel()
|
num_tokens = hidden_states.shape[:-1].numel()
|
||||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
if apply_router_weight_on_input:
|
||||||
if self.apply_router_weight_on_input:
|
|
||||||
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
|
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
|
||||||
_, topk = topk_weights.shape
|
_, topk = topk_weights.shape
|
||||||
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
|
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||||
@@ -66,13 +67,16 @@ class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
|
|||||||
)
|
)
|
||||||
expert_tokens = expert_tokens.to(torch.int64)
|
expert_tokens = expert_tokens.to(torch.int64)
|
||||||
group_list_type = 1 # `count` mode
|
group_list_type = 1 # `count` mode
|
||||||
context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
|
|
||||||
|
|
||||||
return TokenDispatchResult(
|
return MoETokenDispatchOutput(
|
||||||
hidden_states=sorted_hidden_states,
|
hidden_states=sorted_hidden_states,
|
||||||
group_list=expert_tokens,
|
group_list=expert_tokens,
|
||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
context_metadata=context_metadata,
|
combine_metadata=MoEAllGatherCombineMetadata(
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
expanded_row_idx=expanded_row_idx,
|
||||||
|
restore_shape=restore_shape,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def moe_init_routing(self, x, expert_idx, active_num, active_expert_range):
|
def moe_init_routing(self, x, expert_idx, active_num, active_expert_range):
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from vllm.distributed import get_ep_group
|
|||||||
from vllm_ascend._310p.fused_moe.experts_selector import select_experts
|
from vllm_ascend._310p.fused_moe.experts_selector import select_experts
|
||||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
|
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||||
from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType
|
from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType
|
||||||
|
|
||||||
from .registry import register_scheme
|
from .registry import register_scheme
|
||||||
@@ -95,7 +96,9 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
|
|||||||
log2phy: torch.Tensor | None = None,
|
log2phy: torch.Tensor | None = None,
|
||||||
global_redundant_expert_num: int = 0,
|
global_redundant_expert_num: int = 0,
|
||||||
pertoken_scale: Any | None = None,
|
pertoken_scale: Any | None = None,
|
||||||
**kwargs,
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
mc2_mask: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
||||||
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
||||||
@@ -128,15 +131,19 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
|
|||||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||||
|
|
||||||
final_hidden_states = moe_comm_method.fused_experts(
|
final_hidden_states = moe_comm_method.fused_experts(
|
||||||
hidden_states=x,
|
fused_experts_input=build_fused_experts_input(
|
||||||
w1=layer.w13_weight,
|
hidden_states=x,
|
||||||
w1_scale=layer.w13_weight_scale,
|
topk_weights=topk_weights,
|
||||||
w2=layer.w2_weight,
|
topk_ids=topk_ids,
|
||||||
w2_scale=layer.w2_weight_scale,
|
w1=layer.w13_weight,
|
||||||
topk_weights=topk_weights,
|
w2=layer.w2_weight,
|
||||||
topk_ids=topk_ids,
|
quant_type=self.quant_type,
|
||||||
expert_map=expert_map,
|
dynamic_eplb=False,
|
||||||
use_int8_w8a8=True,
|
expert_map=expert_map,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||||
final_hidden_states += zero_expert_result
|
final_hidden_states += zero_expert_result
|
||||||
|
|||||||
@@ -41,7 +41,8 @@ from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
|
|||||||
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
|
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
|
||||||
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
|
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
|
||||||
from vllm_ascend.quantization.methods.base import QuantType
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||||
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
from vllm_ascend.utils import (
|
from vllm_ascend.utils import (
|
||||||
ACL_FORMAT_FRACTAL_NZ,
|
ACL_FORMAT_FRACTAL_NZ,
|
||||||
enable_sp,
|
enable_sp,
|
||||||
@@ -113,7 +114,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
enable_force_load_balance: bool = False,
|
enable_force_load_balance: bool = False,
|
||||||
log2phy: torch.Tensor = None,
|
log2phy: torch.Tensor = None,
|
||||||
**kwargs,
|
global_redundant_expert_num: int = 0,
|
||||||
|
pertoken_scale: torch.Tensor | None = None,
|
||||||
|
mc2_mask: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
||||||
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
||||||
@@ -167,7 +170,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
# (due to signature constraints), we are forced to use a placeholder empty tensor.
|
# (due to signature constraints), we are forced to use a placeholder empty tensor.
|
||||||
# This TODO tracks the requirement to update the C++ operator to accept Optional[Tensor]
|
# This TODO tracks the requirement to update the C++ operator to accept Optional[Tensor]
|
||||||
# or None for scales in non-quantized scenarios.
|
# or None for scales in non-quantized scenarios.
|
||||||
if get_forward_context().moe_comm_type == MoECommType.FUSED_MC2:
|
if _EXTRA_CTX.moe_comm_type == MoECommType.FUSED_MC2:
|
||||||
w1 = [layer.w13_weight]
|
w1 = [layer.w13_weight]
|
||||||
w1_scale = [torch.tensor([], dtype=torch.int64)]
|
w1_scale = [torch.tensor([], dtype=torch.int64)]
|
||||||
w2 = [layer.w2_weight]
|
w2 = [layer.w2_weight]
|
||||||
@@ -179,21 +182,26 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
w2_scale = None
|
w2_scale = None
|
||||||
|
|
||||||
final_hidden_states = moe_comm_method.fused_experts(
|
final_hidden_states = moe_comm_method.fused_experts(
|
||||||
hidden_states=x,
|
fused_experts_input=build_fused_experts_input(
|
||||||
w1=w1,
|
hidden_states=x,
|
||||||
w2=w2,
|
topk_weights=topk_weights,
|
||||||
w1_scale=w1_scale,
|
topk_ids=topk_ids,
|
||||||
w2_scale=w2_scale,
|
w1=w1,
|
||||||
w1_bias=layer.w13_bias if self.moe.has_bias else None,
|
w2=w2,
|
||||||
w2_bias=layer.w2_bias if self.moe.has_bias else None,
|
w1_bias=layer.w13_bias if self.moe.has_bias else None,
|
||||||
activation=activation,
|
w2_bias=layer.w2_bias if self.moe.has_bias else None,
|
||||||
topk_weights=topk_weights,
|
quant_type=QuantType.NONE,
|
||||||
topk_ids=topk_ids,
|
dynamic_eplb=self.dynamic_eplb,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
global_redundant_expert_num=global_redundant_expert_num,
|
||||||
dynamic_eplb=self.dynamic_eplb,
|
mc2_mask=mc2_mask,
|
||||||
log2phy=log2phy,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
mc2_mask=kwargs.get("mc2_mask"),
|
log2phy=log2phy,
|
||||||
|
pertoken_scale=pertoken_scale,
|
||||||
|
activation=activation,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||||
final_hidden_states += zero_expert_result
|
final_hidden_states += zero_expert_result
|
||||||
@@ -474,23 +482,23 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
|
|
||||||
set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)
|
set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)
|
||||||
|
|
||||||
hidden_states, router_logits, mc2_mask, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
|
prepare_output = _EXTRA_CTX.moe_comm_method.prepare(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
replace_allreduce=_EXTRA_CTX.flash_comm_v1_enabled,
|
replace_allreduce=_EXTRA_CTX.flash_comm_v1_enabled,
|
||||||
enable_shared_expert_dp=self.enable_shared_expert_dp,
|
enable_shared_expert_dp=self.enable_shared_expert_dp,
|
||||||
quant_type=self.quant_type,
|
quant_type=self.quant_type,
|
||||||
)
|
)
|
||||||
|
hidden_states = prepare_output.hidden_states
|
||||||
|
router_logits = prepare_output.router_logits
|
||||||
|
mc2_mask = prepare_output.mc2_mask
|
||||||
|
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
|
||||||
|
pertoken_scale = prepare_output.pertoken_scale
|
||||||
|
|
||||||
# Make sure the default stream waits for the gate stream to finish.
|
# Make sure the default stream waits for the gate stream to finish.
|
||||||
if self.multistream_overlap_gate:
|
if self.multistream_overlap_gate:
|
||||||
torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream)
|
torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream)
|
||||||
|
|
||||||
if isinstance(hidden_states, tuple):
|
|
||||||
hidden_states, pertoken_scale = hidden_states
|
|
||||||
else:
|
|
||||||
pertoken_scale = None
|
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
fused_experts_results: FusedExpertsResult = self.quant_method.apply(
|
fused_experts_results: FusedExpertsResult = self.quant_method.apply(
|
||||||
layer=self,
|
layer=self,
|
||||||
@@ -538,7 +546,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
routed_out = _EXTRA_CTX.moe_comm_method.finalize(
|
routed_out = _EXTRA_CTX.moe_comm_method.finalize(
|
||||||
hidden_states=fused_experts_results.routed_out,
|
hidden_states=fused_experts_results.routed_out,
|
||||||
reduce_results=self.reduce_results,
|
reduce_results=self.reduce_results,
|
||||||
context_metadata=context_metadata,
|
padded_hidden_states_shape=padded_hidden_states_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_with_event:
|
if return_with_event:
|
||||||
|
|||||||
@@ -24,6 +24,13 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
|||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||||
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||||
|
MoEFusedExpertsInput,
|
||||||
|
MoEMlpComputeInput,
|
||||||
|
MoEPrepareOutput,
|
||||||
|
build_mlp_compute_input,
|
||||||
|
build_token_dispatch_input,
|
||||||
|
)
|
||||||
from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
||||||
PrepareAndFinalize,
|
PrepareAndFinalize,
|
||||||
PrepareAndFinalizeWithAll2All,
|
PrepareAndFinalizeWithAll2All,
|
||||||
@@ -36,8 +43,7 @@ from vllm_ascend.ops.fused_moe.token_dispatcher import (
|
|||||||
TokenDispatcherWithAllGather,
|
TokenDispatcherWithAllGather,
|
||||||
TokenDispatcherWithMC2,
|
TokenDispatcherWithMC2,
|
||||||
)
|
)
|
||||||
from vllm_ascend.quantization.methods.base import QuantType
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
from vllm_ascend.quantization.quant_parser import parse_mxfp_quant_params
|
|
||||||
|
|
||||||
_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
|
_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
|
||||||
|
|
||||||
@@ -90,131 +96,70 @@ class MoECommMethod(ABC):
|
|||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
quant_type: QuantType = QuantType.NONE,
|
quant_type: QuantType = QuantType.NONE,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
) -> MoEPrepareOutput:
|
||||||
hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
|
return self.prepare_finalize.prepare(
|
||||||
hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce, quant_type
|
hidden_states,
|
||||||
|
router_logits,
|
||||||
|
enable_shared_expert_dp,
|
||||||
|
replace_allreduce,
|
||||||
|
quant_type,
|
||||||
)
|
)
|
||||||
return hidden_states, router_logits, mc2_mask, context_metadata
|
|
||||||
|
|
||||||
def finalize(
|
def finalize(
|
||||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
reduce_results: bool,
|
||||||
|
padded_hidden_states_shape: torch.Size | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, context_metadata)
|
hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, padded_hidden_states_shape)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def fused_experts(
|
def fused_experts(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
fused_experts_input: MoEFusedExpertsInput,
|
||||||
w1: torch.Tensor | list[torch.Tensor],
|
|
||||||
w2: torch.Tensor | list[torch.Tensor],
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
activation: str = "silu",
|
|
||||||
w1_bias: torch.Tensor = None,
|
|
||||||
w2_bias: torch.Tensor = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
use_int8_w8a8: bool = False,
|
|
||||||
use_int4_w4a8: bool = False,
|
|
||||||
use_int4_w4a16: bool = False,
|
|
||||||
expert_map: torch.Tensor | None = None,
|
|
||||||
w1_scale: list[torch.Tensor] | None = None,
|
|
||||||
w2_scale: list[torch.Tensor] | None = None,
|
|
||||||
w1_scale_bias: torch.Tensor = None,
|
|
||||||
w2_scale_bias: torch.Tensor = None,
|
|
||||||
w1_offset: torch.Tensor | None = None,
|
|
||||||
w2_offset: torch.Tensor | None = None,
|
|
||||||
# For load balance
|
|
||||||
log2phy: torch.Tensor = None,
|
|
||||||
need_trans: bool = False,
|
|
||||||
dynamic_eplb: bool = False,
|
|
||||||
mc2_mask: torch.Tensor = None,
|
|
||||||
pertoken_scale: torch.Tensor | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
# Check constraints
|
# Check constraints
|
||||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
|
assert fused_experts_input.hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
|
||||||
|
|
||||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||||
assert moe_comm_method is not None, "Missing communication context"
|
assert moe_comm_method is not None, "Missing communication context"
|
||||||
|
|
||||||
before_dispatch_evt = torch.npu.current_stream().record_event()
|
before_dispatch_evt = torch.npu.current_stream().record_event()
|
||||||
# Apply log2phy if needed
|
routed_topk_ids = fused_experts_input.topk_ids
|
||||||
if log2phy is not None:
|
if fused_experts_input.routing.log2phy is not None:
|
||||||
topk_ids = log2phy[topk_ids]
|
routed_topk_ids = fused_experts_input.routing.log2phy[routed_topk_ids]
|
||||||
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced
|
|
||||||
# by different quantization modes will be consolidated into a dataclass in a follow-up.
|
token_dispatch_input = build_token_dispatch_input(
|
||||||
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
|
fused_experts_input=fused_experts_input,
|
||||||
dispatch_with_quant = use_int8_w8a8 or use_int4_w4a8 or use_mxfp_quant
|
topk_ids=routed_topk_ids,
|
||||||
act_quant_type, weight_quant_type, scale_type, per_token_scale_type, round_mode = parse_mxfp_quant_params(
|
)
|
||||||
**kwargs
|
token_dispatch_output = self.token_dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
|
||||||
|
|
||||||
|
mlp_compute_input = build_mlp_compute_input(
|
||||||
|
fused_experts_input=fused_experts_input,
|
||||||
|
token_dispatch_output=token_dispatch_output,
|
||||||
|
use_fusion_ops=self.use_fusion_ops,
|
||||||
)
|
)
|
||||||
|
|
||||||
dispatch_kwargs = {
|
mlp_output = self._apply_mlp(mlp_compute_input)
|
||||||
"hidden_states": hidden_states,
|
|
||||||
"topk_weights": topk_weights,
|
|
||||||
"topk_ids": topk_ids,
|
|
||||||
"expert_map": expert_map,
|
|
||||||
"global_redundant_expert_num": self.moe_config.global_redundant_expert_num,
|
|
||||||
"mc2_mask": mc2_mask,
|
|
||||||
"apply_router_weight_on_input": apply_router_weight_on_input,
|
|
||||||
"dynamic_eplb": dynamic_eplb,
|
|
||||||
"pertoken_scale": pertoken_scale,
|
|
||||||
}
|
|
||||||
|
|
||||||
if isinstance(self.token_dispatcher, TokenDispatcherWithMC2):
|
|
||||||
dispatch_kwargs["with_quant"] = dispatch_with_quant
|
|
||||||
dispatch_kwargs["comm_quant_mode"] = kwargs.get("comm_quant_mode")
|
|
||||||
dispatch_kwargs["y_dtype"] = act_quant_type if use_mxfp_quant else None
|
|
||||||
dispatch_kwargs["use_mxfp_quant"] = use_mxfp_quant
|
|
||||||
else:
|
|
||||||
dispatch_kwargs["with_quant"] = use_int8_w8a8 or use_int4_w4a8
|
|
||||||
|
|
||||||
dispatch_results = self.token_dispatcher.token_dispatch(**dispatch_kwargs)
|
|
||||||
|
|
||||||
mlp_output = unified_apply_mlp(
|
|
||||||
hidden_states=dispatch_results.hidden_states,
|
|
||||||
w1=w1,
|
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2=w2,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
w1_bias=w1_bias,
|
|
||||||
w2_bias=w2_bias,
|
|
||||||
activation=activation,
|
|
||||||
group_list=dispatch_results.group_list,
|
|
||||||
dynamic_scale=dispatch_results.dynamic_scale,
|
|
||||||
group_list_type=dispatch_results.group_list_type,
|
|
||||||
w1_scale_bias=w1_scale_bias,
|
|
||||||
w2_scale_bias=w2_scale_bias,
|
|
||||||
w1_offset=w1_offset,
|
|
||||||
w2_offset=w2_offset,
|
|
||||||
topk_scales=dispatch_results.topk_scales,
|
|
||||||
with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16 or use_mxfp_quant,
|
|
||||||
fusion=(use_int8_w8a8 or use_mxfp_quant) and self.use_fusion_ops,
|
|
||||||
need_trans=need_trans,
|
|
||||||
dynamic_eplb=dynamic_eplb,
|
|
||||||
use_mxfp_quant=use_mxfp_quant,
|
|
||||||
act_quant_type=act_quant_type,
|
|
||||||
weight_quant_type=weight_quant_type,
|
|
||||||
scale_type=scale_type,
|
|
||||||
per_token_scale_type=per_token_scale_type,
|
|
||||||
round_mode=round_mode,
|
|
||||||
use_bf16=(hidden_states.dtype == torch.bfloat16),
|
|
||||||
rollback_quant_config=kwargs.get("rollback_quant_config"),
|
|
||||||
)
|
|
||||||
|
|
||||||
before_combine_evt = torch.npu.current_stream().record_event()
|
before_combine_evt = torch.npu.current_stream().record_event()
|
||||||
combine_results = self.token_dispatcher.token_combine(
|
routed_out = self.token_dispatcher.token_combine(
|
||||||
hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
|
hidden_states=mlp_output,
|
||||||
|
combine_metadata=token_dispatch_output.combine_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
return FusedExpertsResult(
|
return FusedExpertsResult(
|
||||||
routed_out=combine_results.routed_out,
|
routed_out=routed_out,
|
||||||
before_dispatch_evt=before_dispatch_evt,
|
before_dispatch_evt=before_dispatch_evt,
|
||||||
before_combine_evt=before_combine_evt,
|
before_combine_evt=before_combine_evt,
|
||||||
group_list_type=dispatch_results.group_list_type,
|
group_list_type=token_dispatch_output.group_list_type,
|
||||||
expert_tokens=dispatch_results.group_list,
|
expert_tokens=token_dispatch_output.group_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _apply_mlp(self, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
|
||||||
|
return unified_apply_mlp(mlp_compute_input=mlp_compute_input)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _get_token_dispatcher(self) -> MoETokenDispatcher:
|
def _get_token_dispatcher(self) -> MoETokenDispatcher:
|
||||||
raise NotImplementedError("_get_token_dispatcher function not implemented.")
|
raise NotImplementedError("_get_token_dispatcher function not implemented.")
|
||||||
@@ -317,54 +262,32 @@ class FusedMC2CommImpl(MoECommMethod):
|
|||||||
|
|
||||||
def fused_experts(
|
def fused_experts(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
fused_experts_input: MoEFusedExpertsInput,
|
||||||
w1: torch.Tensor | list[torch.Tensor],
|
|
||||||
w2: torch.Tensor | list[torch.Tensor],
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
activation: str = "silu",
|
|
||||||
w1_bias: torch.Tensor = None,
|
|
||||||
w2_bias: torch.Tensor = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
use_int8_w8a8: bool = False,
|
|
||||||
use_int4_w4a8: bool = False,
|
|
||||||
use_int4_w4a16: bool = False,
|
|
||||||
expert_map: torch.Tensor | None = None,
|
|
||||||
w1_scale: list[torch.Tensor] | None = None,
|
|
||||||
w2_scale: list[torch.Tensor] | None = None,
|
|
||||||
w1_scale_bias: torch.Tensor = None,
|
|
||||||
w2_scale_bias: torch.Tensor = None,
|
|
||||||
w1_offset: torch.Tensor | None = None,
|
|
||||||
w2_offset: torch.Tensor | None = None,
|
|
||||||
# For load balance
|
|
||||||
log2phy: torch.Tensor = None,
|
|
||||||
need_trans: bool = False,
|
|
||||||
dynamic_eplb: bool = False,
|
|
||||||
mc2_mask: torch.Tensor = None,
|
|
||||||
pertoken_scale: torch.Tensor | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
assert not (fused_experts_input.weights.w1_scale is None or fused_experts_input.weights.w2_scale is None), (
|
||||||
|
"w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), (
|
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), (
|
||||||
"token_dispatcher must be an instance of TokenDispatcherWithMC2."
|
"token_dispatcher must be an instance of TokenDispatcherWithMC2."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply log2phy if needed
|
# Apply log2phy if needed
|
||||||
if log2phy is not None:
|
topk_ids = fused_experts_input.topk_ids
|
||||||
topk_ids = log2phy[topk_ids]
|
if fused_experts_input.routing.log2phy is not None:
|
||||||
|
topk_ids = fused_experts_input.routing.log2phy[topk_ids]
|
||||||
|
|
||||||
expert_tokens = None
|
expert_tokens = None
|
||||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
out = torch.empty_like(hidden_states)
|
out = torch.empty_like(fused_experts_input.hidden_states)
|
||||||
torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore
|
torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore
|
||||||
x=hidden_states,
|
x=fused_experts_input.hidden_states,
|
||||||
weight1=w1,
|
weight1=fused_experts_input.weights.w1,
|
||||||
weight2=w2,
|
weight2=fused_experts_input.weights.w2,
|
||||||
expert_idx=topk_ids,
|
expert_idx=topk_ids,
|
||||||
scale1=w1_scale,
|
scale1=fused_experts_input.weights.w1_scale,
|
||||||
scale2=w2_scale,
|
scale2=fused_experts_input.weights.w2_scale,
|
||||||
probs=topk_weights.to(torch.float32),
|
probs=fused_experts_input.topk_weights.to(torch.float32),
|
||||||
group=self.token_dispatcher.moe_all_to_all_group_name,
|
group=self.token_dispatcher.moe_all_to_all_group_name,
|
||||||
max_output_size=65536,
|
max_output_size=65536,
|
||||||
out=out,
|
out=out,
|
||||||
@@ -372,16 +295,16 @@ class FusedMC2CommImpl(MoECommMethod):
|
|||||||
)
|
)
|
||||||
expert_tokens = self.expert_token_nums
|
expert_tokens = self.expert_token_nums
|
||||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||||
assert expert_map is not None, "expert_map cannot be None."
|
assert fused_experts_input.routing.expert_map is not None, "expert_map cannot be None."
|
||||||
out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
|
out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
|
||||||
x=hidden_states,
|
x=fused_experts_input.hidden_states,
|
||||||
expert_ids=topk_ids,
|
expert_ids=topk_ids,
|
||||||
gmm1_permuted_weight=w1,
|
gmm1_permuted_weight=fused_experts_input.weights.w1,
|
||||||
gmm1_permuted_weight_scale=w1_scale,
|
gmm1_permuted_weight_scale=fused_experts_input.weights.w1_scale,
|
||||||
gmm2_weight=w2,
|
gmm2_weight=fused_experts_input.weights.w2,
|
||||||
gmm2_weight_scale=w2_scale,
|
gmm2_weight_scale=fused_experts_input.weights.w2_scale,
|
||||||
expert_smooth_scales=None,
|
expert_smooth_scales=None,
|
||||||
expert_scales=topk_weights.to(torch.float32),
|
expert_scales=fused_experts_input.topk_weights.to(torch.float32),
|
||||||
group_ep=self.token_dispatcher.moe_all_to_all_group_name,
|
group_ep=self.token_dispatcher.moe_all_to_all_group_name,
|
||||||
ep_rank_size=self.token_dispatcher.ep_world_size,
|
ep_rank_size=self.token_dispatcher.ep_world_size,
|
||||||
ep_rank_id=self.token_dispatcher.ep_rank_id,
|
ep_rank_id=self.token_dispatcher.ep_rank_id,
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from vllm_ascend.device.mxfp_compat import (
|
|||||||
ensure_mxfp8_moe_available,
|
ensure_mxfp8_moe_available,
|
||||||
)
|
)
|
||||||
from vllm_ascend.ops.activation import AscendSwigluOAIAndMul
|
from vllm_ascend.ops.activation import AscendSwigluOAIAndMul
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
|
||||||
from vllm_ascend.utils import (
|
from vllm_ascend.utils import (
|
||||||
dispose_tensor,
|
dispose_tensor,
|
||||||
enable_custom_op,
|
enable_custom_op,
|
||||||
@@ -95,27 +96,17 @@ def quant_apply_mlp(
|
|||||||
w2_offset: torch.Tensor | None = None,
|
w2_offset: torch.Tensor | None = None,
|
||||||
fusion: bool = False,
|
fusion: bool = False,
|
||||||
dynamic_eplb: bool = False,
|
dynamic_eplb: bool = False,
|
||||||
**kwargs,
|
use_mxfp_quant: bool = False,
|
||||||
|
act_quant_type: torch.dtype = torch.float8_e4m3fn,
|
||||||
|
weight_quant_type: torch.dtype | None = None,
|
||||||
|
scale_type: torch.dtype | None = None,
|
||||||
|
per_token_scale_type: torch.dtype | None = None,
|
||||||
|
use_bf16: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced by different
|
|
||||||
# quantization modes will be consolidated into a dataclass in a follow-up.
|
|
||||||
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
|
|
||||||
act_quant_type = torch.float8_e4m3fn
|
|
||||||
weight_quant_type = None
|
|
||||||
scale_type = None
|
|
||||||
per_token_scale_type = None
|
|
||||||
use_bf16 = True
|
|
||||||
|
|
||||||
input_hidden_dtype = hidden_states.dtype
|
input_hidden_dtype = hidden_states.dtype
|
||||||
use_gmm_swiglu_quant_fusion = use_mxfp_quant or (fusion and not dynamic_eplb)
|
use_gmm_swiglu_quant_fusion = use_mxfp_quant or (fusion and not dynamic_eplb)
|
||||||
|
|
||||||
if use_mxfp_quant:
|
if use_mxfp_quant:
|
||||||
act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn)
|
|
||||||
weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn)
|
|
||||||
scale_type = kwargs.get("scale_type")
|
|
||||||
per_token_scale_type = kwargs.get("per_token_scale_type")
|
|
||||||
use_bf16 = kwargs.get("use_bf16", True)
|
|
||||||
|
|
||||||
ensure_mxfp8_moe_available("MXFP MoE MLP path")
|
ensure_mxfp8_moe_available("MXFP MoE MLP path")
|
||||||
|
|
||||||
if w1_scale_bias is not None or w2_scale_bias is not None:
|
if w1_scale_bias is not None or w2_scale_bias is not None:
|
||||||
@@ -393,34 +384,32 @@ def unquant_apply_mlp(
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def unified_apply_mlp(
|
def unified_apply_mlp(*, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w1: torch.Tensor | list[torch.Tensor],
|
|
||||||
w2: torch.Tensor | list[torch.Tensor],
|
|
||||||
group_list: torch.Tensor,
|
|
||||||
w1_scale: list[torch.Tensor] | None = None,
|
|
||||||
w2_scale: list[torch.Tensor] | None = None,
|
|
||||||
activation: str | None = None,
|
|
||||||
w1_bias: torch.Tensor = None,
|
|
||||||
w2_bias: torch.Tensor = None,
|
|
||||||
dynamic_scale: torch.Tensor = None,
|
|
||||||
group_list_type: int = 1,
|
|
||||||
w1_scale_bias: torch.Tensor = None,
|
|
||||||
w2_scale_bias: torch.Tensor = None,
|
|
||||||
w1_offset: torch.Tensor | None = None,
|
|
||||||
w2_offset: torch.Tensor | None = None,
|
|
||||||
topk_scales: torch.Tensor | None = None,
|
|
||||||
with_quant: bool = False,
|
|
||||||
fusion: bool = False,
|
|
||||||
need_trans: bool = True,
|
|
||||||
dynamic_eplb: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
"""
|
||||||
Unified MoE MLP entry.
|
Unified MoE MLP entry.
|
||||||
Quant path is dispatched by DeviceOperator with explicit quant-type flags.
|
Quant path is dispatched by DeviceOperator with explicit typed kernel flags.
|
||||||
"""
|
"""
|
||||||
if not with_quant:
|
hidden_states = mlp_compute_input.hidden_states
|
||||||
|
group_list = mlp_compute_input.group_list
|
||||||
|
group_list_type = mlp_compute_input.group_list_type
|
||||||
|
dynamic_scale = mlp_compute_input.dynamic_scale
|
||||||
|
topk_scales = mlp_compute_input.topk_scales
|
||||||
|
w1 = mlp_compute_input.weights.w1
|
||||||
|
w2 = mlp_compute_input.weights.w2
|
||||||
|
w1_bias = mlp_compute_input.weights.w1_bias
|
||||||
|
w2_bias = mlp_compute_input.weights.w2_bias
|
||||||
|
w1_scale = mlp_compute_input.weights.w1_scale
|
||||||
|
w2_scale = mlp_compute_input.weights.w2_scale
|
||||||
|
w1_scale_bias = mlp_compute_input.weights.w1_scale_bias
|
||||||
|
w2_scale_bias = mlp_compute_input.weights.w2_scale_bias
|
||||||
|
w1_offset = mlp_compute_input.weights.w1_offset
|
||||||
|
w2_offset = mlp_compute_input.weights.w2_offset
|
||||||
|
activation = mlp_compute_input.activation
|
||||||
|
need_trans = mlp_compute_input.need_trans
|
||||||
|
dynamic_eplb = mlp_compute_input.dynamic_eplb
|
||||||
|
fusion = mlp_compute_input.fusion
|
||||||
|
|
||||||
|
if not mlp_compute_input.quant.is_quant:
|
||||||
return unquant_apply_mlp(
|
return unquant_apply_mlp(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
@@ -435,13 +424,22 @@ def unified_apply_mlp(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert w1_scale is not None and w2_scale is not None
|
assert w1_scale is not None and w2_scale is not None
|
||||||
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced by different
|
act_quant_type = torch.float8_e4m3fn
|
||||||
# quantization modes will be consolidated into a dataclass in a follow-up.
|
weight_quant_type = torch.float8_e4m3fn
|
||||||
act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn)
|
scale_type = None
|
||||||
weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn)
|
per_token_scale_type = None
|
||||||
scale_type = kwargs.get("scale_type")
|
use_bf16 = hidden_states.dtype == torch.bfloat16
|
||||||
per_token_scale_type = kwargs.get("per_token_scale_type")
|
use_mxfp_quant = mlp_compute_input.quant.is_mxfp
|
||||||
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
|
|
||||||
|
if use_mxfp_quant:
|
||||||
|
mxfp = mlp_compute_input.quant.mxfp
|
||||||
|
assert mxfp is not None, "mlp_compute_input.quant.mxfp is required when quant_type is MXFP8."
|
||||||
|
act_quant_type = mxfp.act_quant_type or act_quant_type
|
||||||
|
weight_quant_type = mxfp.weight_quant_type or weight_quant_type
|
||||||
|
scale_type = mxfp.scale_dtype
|
||||||
|
per_token_scale_type = mxfp.per_token_scale_dtype
|
||||||
|
use_bf16 = mxfp.use_bf16
|
||||||
|
|
||||||
return quant_apply_mlp(
|
return quant_apply_mlp(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
@@ -457,10 +455,10 @@ def unified_apply_mlp(
|
|||||||
w2_offset=w2_offset,
|
w2_offset=w2_offset,
|
||||||
fusion=fusion,
|
fusion=fusion,
|
||||||
dynamic_eplb=dynamic_eplb,
|
dynamic_eplb=dynamic_eplb,
|
||||||
|
use_mxfp_quant=use_mxfp_quant,
|
||||||
act_quant_type=act_quant_type,
|
act_quant_type=act_quant_type,
|
||||||
weight_quant_type=weight_quant_type,
|
weight_quant_type=weight_quant_type,
|
||||||
scale_type=scale_type,
|
scale_type=scale_type,
|
||||||
per_token_scale_type=per_token_scale_type,
|
per_token_scale_type=per_token_scale_type,
|
||||||
use_mxfp_quant=use_mxfp_quant,
|
use_bf16=use_bf16,
|
||||||
use_bf16=kwargs.get("use_bf16", True),
|
|
||||||
)
|
)
|
||||||
|
|||||||
244
vllm_ascend/ops/fused_moe/moe_runtime_args.py
Normal file
244
vllm_ascend/ops/fused_moe/moe_runtime_args.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""Typed runtime contracts and builders for fused MoE execution.
|
||||||
|
|
||||||
|
This module is the single entry point for the runtime payloads used across the
|
||||||
|
fused MoE pipeline.
|
||||||
|
|
||||||
|
Relationship overview:
|
||||||
|
|
||||||
|
stage params: reusable sub-payloads
|
||||||
|
- MoERoutingParams
|
||||||
|
- MoEQuantParams
|
||||||
|
- internal MXFP leaf: MoEMxfpParams
|
||||||
|
|
||||||
|
stage contracts: stage input/output payloads
|
||||||
|
prepare
|
||||||
|
-> MoEPrepareOutput
|
||||||
|
|
||||||
|
fused_experts input
|
||||||
|
-> MoEFusedExpertsInput
|
||||||
|
|- weights: MoEWeights
|
||||||
|
|- routing: MoERoutingParams
|
||||||
|
|- quant: MoEQuantParams
|
||||||
|
|
||||||
|
dispatch
|
||||||
|
input -> MoETokenDispatchInput
|
||||||
|
output -> MoETokenDispatchOutput[TMoECombineMetadata]
|
||||||
|
TMoECombineMetadata is one of:
|
||||||
|
- MoEAllGatherCombineMetadata
|
||||||
|
- MoEAllToAllCombineMetadata
|
||||||
|
- MoEMC2CombineMetadata
|
||||||
|
|
||||||
|
mlp
|
||||||
|
input -> MoEMlpComputeInput
|
||||||
|
|
||||||
|
combine
|
||||||
|
output -> torch.Tensor
|
||||||
|
|
||||||
|
The helper builders below adapt legacy call sites into these typed contracts.
|
||||||
|
Only the fused_moe package should need to know about the internal MXFP leaf
|
||||||
|
dataclass directly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm_ascend.ops.fused_moe.moe_stage_params as _stage_params
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_stage_contracts import (
|
||||||
|
MoEAllGatherCombineMetadata,
|
||||||
|
MoEAllToAllCombineMetadata,
|
||||||
|
MoEFusedExpertsInput,
|
||||||
|
MoEMC2CombineMetadata,
|
||||||
|
MoEMlpComputeInput,
|
||||||
|
MoEPrepareOutput,
|
||||||
|
MoETokenDispatchInput,
|
||||||
|
MoETokenDispatchOutput,
|
||||||
|
MoEWeights,
|
||||||
|
TMoECombineMetadata,
|
||||||
|
)
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_stage_params import (
|
||||||
|
MoEQuantParams,
|
||||||
|
MoERoutingParams,
|
||||||
|
)
|
||||||
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
|
|
||||||
|
|
||||||
|
def _build_mxfp_params(
|
||||||
|
*,
|
||||||
|
quant_type: QuantType,
|
||||||
|
mxfp_act_quant_type: torch.dtype | None = None,
|
||||||
|
mxfp_weight_quant_type: torch.dtype | None = None,
|
||||||
|
mxfp_scale_dtype: torch.dtype | None = None,
|
||||||
|
mxfp_per_token_scale_dtype: torch.dtype | None = None,
|
||||||
|
mxfp_use_bf16: bool | None = None,
|
||||||
|
) -> _stage_params.MoEMxfpParams | None:
|
||||||
|
if quant_type != QuantType.MXFP8:
|
||||||
|
return None
|
||||||
|
|
||||||
|
has_explicit_mxfp_args = any(
|
||||||
|
value is not None
|
||||||
|
for value in (
|
||||||
|
mxfp_act_quant_type,
|
||||||
|
mxfp_weight_quant_type,
|
||||||
|
mxfp_scale_dtype,
|
||||||
|
mxfp_per_token_scale_dtype,
|
||||||
|
mxfp_use_bf16,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not has_explicit_mxfp_args:
|
||||||
|
raise ValueError("primitive MXFP params are required when quant_type is QuantType.MXFP8.")
|
||||||
|
|
||||||
|
return _stage_params.MoEMxfpParams(
|
||||||
|
act_quant_type=mxfp_act_quant_type,
|
||||||
|
weight_quant_type=mxfp_weight_quant_type,
|
||||||
|
scale_dtype=mxfp_scale_dtype,
|
||||||
|
per_token_scale_dtype=mxfp_per_token_scale_dtype,
|
||||||
|
use_bf16=True if mxfp_use_bf16 is None else mxfp_use_bf16,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_fused_experts_input(
|
||||||
|
*,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
w1: torch.Tensor | list[torch.Tensor],
|
||||||
|
w2: torch.Tensor | list[torch.Tensor],
|
||||||
|
quant_type: QuantType,
|
||||||
|
dynamic_eplb: bool,
|
||||||
|
expert_map: torch.Tensor | None = None,
|
||||||
|
global_redundant_expert_num: int = 0,
|
||||||
|
mc2_mask: torch.Tensor | None = None,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
log2phy: torch.Tensor | None = None,
|
||||||
|
pertoken_scale: torch.Tensor | None = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
need_trans: bool = False,
|
||||||
|
w1_bias: torch.Tensor | None = None,
|
||||||
|
w2_bias: torch.Tensor | None = None,
|
||||||
|
comm_quant_mode: int | None = None,
|
||||||
|
mxfp_act_quant_type: torch.dtype | None = None,
|
||||||
|
mxfp_weight_quant_type: torch.dtype | None = None,
|
||||||
|
mxfp_scale_dtype: torch.dtype | None = None,
|
||||||
|
mxfp_per_token_scale_dtype: torch.dtype | None = None,
|
||||||
|
mxfp_use_bf16: bool | None = None,
|
||||||
|
w1_scale: list[torch.Tensor] | torch.Tensor | None = None,
|
||||||
|
w2_scale: list[torch.Tensor] | torch.Tensor | None = None,
|
||||||
|
w1_scale_bias: torch.Tensor | None = None,
|
||||||
|
w2_scale_bias: torch.Tensor | None = None,
|
||||||
|
w1_offset: torch.Tensor | None = None,
|
||||||
|
w2_offset: torch.Tensor | None = None,
|
||||||
|
) -> MoEFusedExpertsInput:
|
||||||
|
return MoEFusedExpertsInput(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
weights=MoEWeights(
|
||||||
|
w1=w1,
|
||||||
|
w2=w2,
|
||||||
|
w1_bias=w1_bias,
|
||||||
|
w2_bias=w2_bias,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
w1_scale_bias=w1_scale_bias,
|
||||||
|
w2_scale_bias=w2_scale_bias,
|
||||||
|
w1_offset=w1_offset,
|
||||||
|
w2_offset=w2_offset,
|
||||||
|
),
|
||||||
|
routing=MoERoutingParams(
|
||||||
|
expert_map=expert_map,
|
||||||
|
global_redundant_expert_num=global_redundant_expert_num,
|
||||||
|
mc2_mask=mc2_mask,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
log2phy=log2phy,
|
||||||
|
pertoken_scale=pertoken_scale,
|
||||||
|
),
|
||||||
|
activation=activation,
|
||||||
|
need_trans=need_trans,
|
||||||
|
dynamic_eplb=dynamic_eplb,
|
||||||
|
quant=MoEQuantParams(
|
||||||
|
quant_type=quant_type,
|
||||||
|
comm_quant_mode=comm_quant_mode,
|
||||||
|
mxfp=_build_mxfp_params(
|
||||||
|
quant_type=quant_type,
|
||||||
|
mxfp_act_quant_type=mxfp_act_quant_type,
|
||||||
|
mxfp_weight_quant_type=mxfp_weight_quant_type,
|
||||||
|
mxfp_scale_dtype=mxfp_scale_dtype,
|
||||||
|
mxfp_per_token_scale_dtype=mxfp_per_token_scale_dtype,
|
||||||
|
mxfp_use_bf16=mxfp_use_bf16,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_token_dispatch_input(
|
||||||
|
*,
|
||||||
|
fused_experts_input: MoEFusedExpertsInput,
|
||||||
|
topk_ids: torch.Tensor | None = None,
|
||||||
|
) -> MoETokenDispatchInput:
|
||||||
|
return MoETokenDispatchInput(
|
||||||
|
hidden_states=fused_experts_input.hidden_states,
|
||||||
|
topk_weights=fused_experts_input.topk_weights,
|
||||||
|
topk_ids=fused_experts_input.topk_ids if topk_ids is None else topk_ids,
|
||||||
|
routing=fused_experts_input.routing,
|
||||||
|
quant=fused_experts_input.quant,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_mlp_compute_input(
|
||||||
|
*,
|
||||||
|
fused_experts_input: MoEFusedExpertsInput,
|
||||||
|
token_dispatch_output: MoETokenDispatchOutput[TMoECombineMetadata],
|
||||||
|
use_fusion_ops: bool,
|
||||||
|
) -> MoEMlpComputeInput:
|
||||||
|
if fused_experts_input.quant.is_mxfp and fused_experts_input.quant.mxfp is None:
|
||||||
|
raise ValueError("fused_experts_input.quant.mxfp is required when quant_type is QuantType.MXFP8.")
|
||||||
|
|
||||||
|
return MoEMlpComputeInput(
|
||||||
|
hidden_states=token_dispatch_output.hidden_states,
|
||||||
|
group_list=token_dispatch_output.group_list,
|
||||||
|
group_list_type=token_dispatch_output.group_list_type,
|
||||||
|
dynamic_scale=token_dispatch_output.dynamic_scale,
|
||||||
|
topk_scales=token_dispatch_output.topk_scales,
|
||||||
|
weights=fused_experts_input.weights,
|
||||||
|
quant=fused_experts_input.quant,
|
||||||
|
fusion=fused_experts_input.quant.quant_type in (QuantType.W8A8, QuantType.MXFP8) and use_fusion_ops,
|
||||||
|
activation=fused_experts_input.activation,
|
||||||
|
need_trans=fused_experts_input.need_trans,
|
||||||
|
dynamic_eplb=fused_experts_input.dynamic_eplb,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MoEAllGatherCombineMetadata",
|
||||||
|
"MoEAllToAllCombineMetadata",
|
||||||
|
"MoEFusedExpertsInput",
|
||||||
|
"MoEMC2CombineMetadata",
|
||||||
|
"MoEMlpComputeInput",
|
||||||
|
"MoEPrepareOutput",
|
||||||
|
"MoEQuantParams",
|
||||||
|
"MoERoutingParams",
|
||||||
|
"MoETokenDispatchInput",
|
||||||
|
"MoETokenDispatchOutput",
|
||||||
|
"MoEWeights",
|
||||||
|
"TMoECombineMetadata",
|
||||||
|
"build_fused_experts_input",
|
||||||
|
"build_token_dispatch_input",
|
||||||
|
"build_mlp_compute_input",
|
||||||
|
]
|
||||||
154
vllm_ascend/ops/fused_moe/moe_stage_contracts.py
Normal file
154
vllm_ascend/ops/fused_moe/moe_stage_contracts.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEQuantParams, MoERoutingParams
|
||||||
|
|
||||||
|
TMoECombineMetadata = TypeVar("TMoECombineMetadata")
|
||||||
|
|
||||||
|
|
||||||
|
# prepare -> fused_experts
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MoEPrepareOutput:
|
||||||
|
"""Typed output from prepare stage."""
|
||||||
|
|
||||||
|
hidden_states: torch.Tensor
|
||||||
|
router_logits: torch.Tensor
|
||||||
|
mc2_mask: torch.Tensor | None
|
||||||
|
padded_hidden_states_shape: torch.Size | None
|
||||||
|
pertoken_scale: torch.Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MoEWeights:
|
||||||
|
"""Dense and quantized weight payloads consumed by MoE execution."""
|
||||||
|
|
||||||
|
w1: torch.Tensor | list[torch.Tensor]
|
||||||
|
w2: torch.Tensor | list[torch.Tensor]
|
||||||
|
w1_bias: torch.Tensor | None = None
|
||||||
|
w2_bias: torch.Tensor | None = None
|
||||||
|
w1_scale: torch.Tensor | list[torch.Tensor] | None = None
|
||||||
|
w2_scale: torch.Tensor | list[torch.Tensor] | None = None
|
||||||
|
w1_scale_bias: torch.Tensor | None = None
|
||||||
|
w2_scale_bias: torch.Tensor | None = None
|
||||||
|
w1_offset: torch.Tensor | None = None
|
||||||
|
w2_offset: torch.Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MoEFusedExpertsInput:
|
||||||
|
"""Top-level input for the routed experts pipeline."""
|
||||||
|
|
||||||
|
hidden_states: torch.Tensor
|
||||||
|
topk_weights: torch.Tensor
|
||||||
|
topk_ids: torch.Tensor
|
||||||
|
weights: MoEWeights
|
||||||
|
routing: MoERoutingParams
|
||||||
|
quant: MoEQuantParams
|
||||||
|
activation: str = "silu"
|
||||||
|
need_trans: bool = False
|
||||||
|
dynamic_eplb: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MoETokenDispatchInput:
|
||||||
|
"""Input to token dispatch."""
|
||||||
|
|
||||||
|
hidden_states: torch.Tensor
|
||||||
|
topk_weights: torch.Tensor
|
||||||
|
topk_ids: torch.Tensor
|
||||||
|
routing: MoERoutingParams
|
||||||
|
quant: MoEQuantParams
|
||||||
|
|
||||||
|
|
||||||
|
# dispatch carry-over state consumed by combine
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MoEMC2CombineMetadata:
|
||||||
|
topk_ids: torch.Tensor
|
||||||
|
topk_weights: torch.Tensor
|
||||||
|
expert_map: torch.Tensor | None
|
||||||
|
ep_recv_counts: torch.Tensor
|
||||||
|
tp_recv_counts: torch.Tensor
|
||||||
|
assist_info_for_combine: torch.Tensor
|
||||||
|
expand_scales: torch.Tensor | None
|
||||||
|
dispatch_with_quant: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MoEAllGatherCombineMetadata:
|
||||||
|
topk_weights: torch.Tensor
|
||||||
|
expanded_row_idx: torch.Tensor
|
||||||
|
restore_shape: torch.Size
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MoEAllToAllCombineMetadata:
|
||||||
|
input_splits: np.ndarray
|
||||||
|
output_splits: np.ndarray
|
||||||
|
topk_weights: torch.Tensor
|
||||||
|
reversed_local_input_permutation_mapping: torch.Tensor
|
||||||
|
reversed_global_input_permutation_mapping: torch.Tensor | None
|
||||||
|
hidden_shape: torch.Size
|
||||||
|
hidden_shape_before_permute: torch.Size
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MoETokenDispatchOutput(Generic[TMoECombineMetadata]):
|
||||||
|
hidden_states: torch.Tensor
|
||||||
|
group_list: torch.Tensor
|
||||||
|
group_list_type: int
|
||||||
|
combine_metadata: TMoECombineMetadata
|
||||||
|
dynamic_scale: torch.Tensor | None = None
|
||||||
|
topk_scales: torch.Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# dispatch -> mlp -> combine
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MoEMlpComputeInput:
|
||||||
|
"""Input to MLP compute."""
|
||||||
|
|
||||||
|
hidden_states: torch.Tensor
|
||||||
|
group_list: torch.Tensor
|
||||||
|
group_list_type: int
|
||||||
|
dynamic_scale: torch.Tensor | None
|
||||||
|
topk_scales: torch.Tensor | None
|
||||||
|
weights: MoEWeights
|
||||||
|
quant: MoEQuantParams
|
||||||
|
fusion: bool
|
||||||
|
activation: str = "silu"
|
||||||
|
need_trans: bool = False
|
||||||
|
dynamic_eplb: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MoEPrepareOutput",
|
||||||
|
"MoEWeights",
|
||||||
|
"MoEFusedExpertsInput",
|
||||||
|
"MoETokenDispatchInput",
|
||||||
|
"MoEMC2CombineMetadata",
|
||||||
|
"MoEAllGatherCombineMetadata",
|
||||||
|
"MoEAllToAllCombineMetadata",
|
||||||
|
"MoETokenDispatchOutput",
|
||||||
|
"MoEMlpComputeInput",
|
||||||
|
"TMoECombineMetadata",
|
||||||
|
]
|
||||||
86
vllm_ascend/ops/fused_moe/moe_stage_params.py
Normal file
86
vllm_ascend/ops/fused_moe/moe_stage_params.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MoERoutingParams:
|
||||||
|
"""Routing and dispatch side inputs for one MoE invocation.
|
||||||
|
|
||||||
|
`pertoken_scale` is intentionally kept here even though it is not a pure
|
||||||
|
routing concept. It is used by pre-quantized activation flows, currently
|
||||||
|
the AllGather + EP W8A8 prepare path, where prepare emits per-token
|
||||||
|
activation scales and dispatch needs to carry them forward so the MLP
|
||||||
|
quant path can reuse those scales instead of requantizing activations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
expert_map: torch.Tensor | None
|
||||||
|
global_redundant_expert_num: int
|
||||||
|
mc2_mask: torch.Tensor | None
|
||||||
|
apply_router_weight_on_input: bool
|
||||||
|
log2phy: torch.Tensor | None = None
|
||||||
|
# Precomputed activation scales from prepare stage for quantized dispatch.
|
||||||
|
pertoken_scale: torch.Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MoEMxfpParams:
|
||||||
|
"""Internal MXFP-only precision settings used by fused_moe runtime."""
|
||||||
|
|
||||||
|
act_quant_type: torch.dtype | None = None
|
||||||
|
weight_quant_type: torch.dtype | None = None
|
||||||
|
scale_dtype: torch.dtype | None = None
|
||||||
|
per_token_scale_dtype: torch.dtype | None = None
|
||||||
|
use_bf16: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MoEQuantParams:
|
||||||
|
"""Quant mode, backend override, and optional internal MXFP leaf config."""
|
||||||
|
|
||||||
|
quant_type: QuantType = QuantType.NONE
|
||||||
|
comm_quant_mode: int | None = None
|
||||||
|
mxfp: MoEMxfpParams | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_quant(self) -> bool:
|
||||||
|
return self.quant_type != QuantType.NONE
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_mxfp(self) -> bool:
|
||||||
|
return self.quant_type == QuantType.MXFP8
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_int_quant(self) -> bool:
|
||||||
|
return self.quant_type in (QuantType.W8A8, QuantType.W4A8)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dispatch_with_quant(self) -> bool:
|
||||||
|
return self.quant_type in (QuantType.W8A8, QuantType.W4A8, QuantType.MXFP8)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MoERoutingParams",
|
||||||
|
"MoEMxfpParams",
|
||||||
|
"MoEQuantParams",
|
||||||
|
]
|
||||||
@@ -31,7 +31,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
|||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||||
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
|
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
|
||||||
from vllm_ascend.quantization.methods.base import QuantType
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEPrepareOutput
|
||||||
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
|
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
|
||||||
|
|
||||||
|
|
||||||
@@ -64,7 +65,7 @@ class PrepareAndFinalize(ABC):
|
|||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
quant_type: QuantType = QuantType.NONE,
|
quant_type: QuantType = QuantType.NONE,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
) -> MoEPrepareOutput:
|
||||||
"""
|
"""
|
||||||
Prepare tensors before MoE computation. May involve:
|
Prepare tensors before MoE computation. May involve:
|
||||||
- Padding to align communication boundaries
|
- Padding to align communication boundaries
|
||||||
@@ -79,16 +80,20 @@ class PrepareAndFinalize(ABC):
|
|||||||
quant_type: none, w8a8, w4a8 or mxfp8
|
quant_type: none, w8a8, w4a8 or mxfp8
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of:
|
MoEPrepareOutput:
|
||||||
- processed hidden_states (may be padded/sliced/broadcasted)
|
- processed hidden_states (may be padded/sliced/broadcasted)
|
||||||
- processed router_logits (may be recomputed or broadcasted)
|
- processed router_logits (may be recomputed or broadcasted)
|
||||||
- optional communication mask (e.g., mc2_mask for sparse ops)
|
- optional communication mask (e.g., mc2_mask for sparse ops)
|
||||||
- optional context metadata (e.g., saved split_hidden_states for finalization)
|
- optional padded hidden state shape for finalization
|
||||||
|
- optional per-token scale for quantized path
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("Prepare not implemented.")
|
raise NotImplementedError("Prepare not implemented.")
|
||||||
|
|
||||||
def finalize(
|
def finalize(
|
||||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
reduce_results: bool,
|
||||||
|
padded_hidden_states_shape: torch.Size | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Finalize MoE output. May involve:
|
Finalize MoE output. May involve:
|
||||||
@@ -130,7 +135,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
|||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
quant_type=QuantType.NONE,
|
quant_type=QuantType.NONE,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
) -> MoEPrepareOutput:
|
||||||
"""
|
"""
|
||||||
Preparation steps:
|
Preparation steps:
|
||||||
1. Pad hidden_states and router_logits to next multiple of TP size.
|
1. Pad hidden_states and router_logits to next multiple of TP size.
|
||||||
@@ -140,7 +145,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
|||||||
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (hidden_states, router_logits, None, context_metadata) — no mask used in All2All.
|
MoEPrepareOutput where `mc2_mask` is None for All2All path.
|
||||||
"""
|
"""
|
||||||
self.replace_allreduce = replace_allreduce
|
self.replace_allreduce = replace_allreduce
|
||||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||||
@@ -162,12 +167,19 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
|||||||
hidden_states = split_hidden_states[self.tp_rank]
|
hidden_states = split_hidden_states[self.tp_rank]
|
||||||
router_logits = split_router_logits[self.tp_rank]
|
router_logits = split_router_logits[self.tp_rank]
|
||||||
|
|
||||||
context_metadata = {"padded_hidden_states_shape": padded_hidden_states_shape}
|
return MoEPrepareOutput(
|
||||||
|
hidden_states=hidden_states,
|
||||||
return hidden_states, router_logits, None, context_metadata
|
router_logits=router_logits,
|
||||||
|
mc2_mask=None,
|
||||||
|
padded_hidden_states_shape=padded_hidden_states_shape,
|
||||||
|
pertoken_scale=None,
|
||||||
|
)
|
||||||
|
|
||||||
def finalize(
|
def finalize(
|
||||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
reduce_results: bool,
|
||||||
|
padded_hidden_states_shape: torch.Size | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Finalization steps:
|
Finalization steps:
|
||||||
@@ -180,12 +192,11 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
|||||||
|
|
||||||
if not (self.enable_shared_expert_dp or self.replace_allreduce):
|
if not (self.enable_shared_expert_dp or self.replace_allreduce):
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
assert context_metadata is not None
|
assert padded_hidden_states_shape is not None
|
||||||
# Cannot reuse `split_hidden_states` from prepare phase as it
|
# Cannot reuse `split_hidden_states` from prepare phase as it
|
||||||
# may share memory with original hidden_states. Since shared
|
# may share memory with original hidden_states. Since shared
|
||||||
# experts may use the original tensor, reusing it would cause
|
# experts may use the original tensor, reusing it would cause
|
||||||
# in-place modification during all_gather, corrupting the data.
|
# in-place modification during all_gather, corrupting the data.
|
||||||
padded_hidden_states_shape = context_metadata["padded_hidden_states_shape"]
|
|
||||||
gathered_hidden_states = torch.empty(
|
gathered_hidden_states = torch.empty(
|
||||||
padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype
|
padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype
|
||||||
)
|
)
|
||||||
@@ -227,7 +238,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
|||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
quant_type=QuantType.NONE,
|
quant_type=QuantType.NONE,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
) -> MoEPrepareOutput:
|
||||||
"""
|
"""
|
||||||
Preparation steps:
|
Preparation steps:
|
||||||
1. Fetch `mc2_mask` and target padding length from forward context.
|
1. Fetch `mc2_mask` and target padding length from forward context.
|
||||||
@@ -238,7 +249,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
|||||||
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (hidden_states, router_logits, mc2_mask, context_metadata), possibly sliced/padded.
|
MoEPrepareOutput, possibly sliced/padded.
|
||||||
"""
|
"""
|
||||||
self.replace_allreduce = replace_allreduce
|
self.replace_allreduce = replace_allreduce
|
||||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||||
@@ -267,11 +278,13 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
|||||||
hidden_states = split_hidden_states[self.tp_rank]
|
hidden_states = split_hidden_states[self.tp_rank]
|
||||||
router_logits = split_router_logits[self.tp_rank]
|
router_logits = split_router_logits[self.tp_rank]
|
||||||
|
|
||||||
context_metadata = {
|
return MoEPrepareOutput(
|
||||||
"padded_hidden_states_shape": padded_hidden_states_shape,
|
hidden_states=hidden_states,
|
||||||
}
|
router_logits=router_logits,
|
||||||
|
mc2_mask=mc2_mask,
|
||||||
return hidden_states, router_logits, mc2_mask, context_metadata
|
padded_hidden_states_shape=padded_hidden_states_shape,
|
||||||
|
pertoken_scale=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||||
@@ -303,13 +316,13 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
|||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
quant_type=QuantType.NONE,
|
quant_type=QuantType.NONE,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
) -> MoEPrepareOutput:
|
||||||
"""
|
"""
|
||||||
Preparation steps:
|
Preparation steps:
|
||||||
AllGather hidden_states and router_logits to form global tensors.
|
AllGather hidden_states and router_logits to form global tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (global_hidden_states, global_router_logits, None)
|
MoEPrepareOutput with global tensors.
|
||||||
"""
|
"""
|
||||||
if enable_sp():
|
if enable_sp():
|
||||||
return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
|
return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
|
||||||
@@ -318,7 +331,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
|||||||
|
|
||||||
def _prepare_with_ep_group(
|
def _prepare_with_ep_group(
|
||||||
self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE
|
self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
) -> MoEPrepareOutput:
|
||||||
pertoken_scale = None
|
pertoken_scale = None
|
||||||
if quant_type == QuantType.W8A8:
|
if quant_type == QuantType.W8A8:
|
||||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||||
@@ -342,10 +355,13 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
|||||||
if self.multistream_overlap_gate:
|
if self.multistream_overlap_gate:
|
||||||
torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream)
|
torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream)
|
||||||
|
|
||||||
if pertoken_scale is not None:
|
return MoEPrepareOutput(
|
||||||
return (hidden_states, pertoken_scale), router_logits, None, None
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
return hidden_states, router_logits, None, None
|
mc2_mask=None,
|
||||||
|
padded_hidden_states_shape=None,
|
||||||
|
pertoken_scale=pertoken_scale,
|
||||||
|
)
|
||||||
|
|
||||||
def _prepare_with_dp_group(
|
def _prepare_with_dp_group(
|
||||||
self,
|
self,
|
||||||
@@ -354,7 +370,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
|||||||
enable_shared_expert_dp: bool = False,
|
enable_shared_expert_dp: bool = False,
|
||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
quant_type=QuantType.NONE,
|
quant_type=QuantType.NONE,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
) -> MoEPrepareOutput:
|
||||||
"""
|
"""
|
||||||
Preparation steps:
|
Preparation steps:
|
||||||
1. Fetch max token count across DP group from forward context.
|
1. Fetch max token count across DP group from forward context.
|
||||||
@@ -362,7 +378,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
|||||||
3. All-gather across DP group to form global input tensor.
|
3. All-gather across DP group to form global input tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (global_hidden_states, global_router_logits, None, None)
|
MoEPrepareOutput with global tensors.
|
||||||
"""
|
"""
|
||||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||||
if self.moe_config.dp_size > 1:
|
if self.moe_config.dp_size > 1:
|
||||||
@@ -396,10 +412,19 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
|||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
return hidden_states, router_logits, None, None
|
return MoEPrepareOutput(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
mc2_mask=None,
|
||||||
|
padded_hidden_states_shape=None,
|
||||||
|
pertoken_scale=None,
|
||||||
|
)
|
||||||
|
|
||||||
def finalize(
|
def finalize(
|
||||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
reduce_results: bool,
|
||||||
|
padded_hidden_states_shape: torch.Size | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Finalization steps:
|
Finalization steps:
|
||||||
|
|||||||
@@ -21,7 +21,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from typing import Generic
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
@@ -31,25 +31,18 @@ from vllm.distributed.parallel_state import get_ep_group
|
|||||||
from vllm_ascend.device.device_op import DeviceOperator
|
from vllm_ascend.device.device_op import DeviceOperator
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
|
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||||
|
MoEAllGatherCombineMetadata,
|
||||||
|
MoEAllToAllCombineMetadata,
|
||||||
|
MoEMC2CombineMetadata,
|
||||||
|
MoETokenDispatchInput,
|
||||||
|
MoETokenDispatchOutput,
|
||||||
|
TMoECombineMetadata,
|
||||||
|
)
|
||||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
|
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class MoETokenDispatcher(ABC, Generic[TMoECombineMetadata]):
|
||||||
class TokenDispatchResult:
|
|
||||||
hidden_states: torch.Tensor
|
|
||||||
group_list: torch.Tensor
|
|
||||||
group_list_type: int
|
|
||||||
dynamic_scale: torch.Tensor | None = field(default=None)
|
|
||||||
topk_scales: torch.Tensor | None = field(default=None)
|
|
||||||
context_metadata: dict = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TokenCombineResult:
|
|
||||||
routed_out: torch.Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class MoETokenDispatcher(ABC):
|
|
||||||
def __init__(self, **kwargs) -> None:
|
def __init__(self, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the MoE Token Dispatcher.
|
Initialize the MoE Token Dispatcher.
|
||||||
@@ -73,27 +66,21 @@ class MoETokenDispatcher(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def token_dispatch(
|
def token_dispatch(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
token_dispatch_input: MoETokenDispatchInput,
|
||||||
topk_weights: torch.Tensor,
|
) -> MoETokenDispatchOutput[TMoECombineMetadata]:
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
expert_map: torch.Tensor | None = None,
|
|
||||||
global_redundant_expert_num: int = 0,
|
|
||||||
mc2_mask: torch.Tensor | None = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
with_quant: bool = False,
|
|
||||||
dynamic_eplb: bool = False,
|
|
||||||
pertoken_scale: torch.Tensor | None = None,
|
|
||||||
) -> TokenDispatchResult:
|
|
||||||
raise NotImplementedError("Dispatch function not implemented.")
|
raise NotImplementedError("Dispatch function not implemented.")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def token_combine(
|
def token_combine(
|
||||||
self, hidden_states: torch.Tensor, context_metadata: dict, bias: torch.Tensor | None = None
|
self,
|
||||||
) -> TokenCombineResult:
|
hidden_states: torch.Tensor,
|
||||||
|
combine_metadata: TMoECombineMetadata,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError("Combine function not implemented.")
|
raise NotImplementedError("Combine function not implemented.")
|
||||||
|
|
||||||
|
|
||||||
class TokenDispatcherWithMC2(MoETokenDispatcher):
|
class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
device_group = get_mc2_group().device_group
|
device_group = get_mc2_group().device_group
|
||||||
@@ -110,7 +97,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
|||||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||||
# improve communication performance.
|
# improve communication performance.
|
||||||
self.need_expert_scale = is_hierarchical_communication_enabled()
|
self.need_expert_scale = is_hierarchical_communication_enabled()
|
||||||
self.with_quant = False
|
|
||||||
|
|
||||||
# Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
|
# Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
|
||||||
# dispatch & combine operators with different input num_tokens per rank.
|
# dispatch & combine operators with different input num_tokens per rank.
|
||||||
@@ -131,25 +117,23 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
|||||||
|
|
||||||
def get_dispatch_mc2_kwargs(
|
def get_dispatch_mc2_kwargs(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
token_dispatch_input: MoETokenDispatchInput,
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
expert_map: torch.Tensor,
|
|
||||||
mc2_mask: torch.Tensor,
|
|
||||||
global_redundant_expert_num: int = 0,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
|
hidden_states = token_dispatch_input.hidden_states
|
||||||
comm_quant_mode = kwargs.get("comm_quant_mode")
|
topk_weights = token_dispatch_input.topk_weights
|
||||||
|
topk_ids = token_dispatch_input.topk_ids
|
||||||
|
expert_map = token_dispatch_input.routing.expert_map
|
||||||
|
global_redundant_expert_num = token_dispatch_input.routing.global_redundant_expert_num
|
||||||
|
comm_quant_mode = token_dispatch_input.quant.comm_quant_mode
|
||||||
|
|
||||||
|
assert expert_map is not None, "expert_map is required for MC2 token dispatch."
|
||||||
# NOTE: quant_mode differs by quant feature:
|
# NOTE: quant_mode differs by quant feature:
|
||||||
# - Legacy int communication quantization uses quant_mode=2.
|
# - Legacy int communication quantization uses quant_mode=2.
|
||||||
# - A5 MXFP8 communication uses quant_mode=4.
|
# - A5 MXFP8 communication uses quant_mode=4.
|
||||||
# TODO(linfeng): The quantization-related parameters need to be consolidated into a single
|
|
||||||
# dataclass, and the FP8 MoE code path should be integrated into it going forward.
|
|
||||||
if comm_quant_mode is not None:
|
if comm_quant_mode is not None:
|
||||||
quant_mode = comm_quant_mode
|
quant_mode = comm_quant_mode
|
||||||
elif self.with_quant:
|
elif token_dispatch_input.quant.dispatch_with_quant:
|
||||||
quant_mode = 4 if self.a5_need_extra_args and use_mxfp_quant else 2
|
quant_mode = 4 if self.a5_need_extra_args and token_dispatch_input.quant.is_mxfp else 2
|
||||||
else:
|
else:
|
||||||
quant_mode = 0
|
quant_mode = 0
|
||||||
self.moe_expert_num = len(expert_map) + global_redundant_expert_num
|
self.moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||||
@@ -178,10 +162,13 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
|||||||
"tp_rank_id": 0,
|
"tp_rank_id": 0,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if self.a5_need_extra_args and use_mxfp_quant:
|
if self.a5_need_extra_args and token_dispatch_input.quant.is_mxfp:
|
||||||
y_dtype = kwargs.get("y_dtype")
|
y_dtype = torch.float8_e4m3fn
|
||||||
if self.with_quant:
|
if (
|
||||||
y_dtype = torch.float8_e4m3fn if y_dtype is None else y_dtype
|
token_dispatch_input.quant.mxfp is not None
|
||||||
|
and token_dispatch_input.quant.mxfp.act_quant_type is not None
|
||||||
|
):
|
||||||
|
y_dtype = token_dispatch_input.quant.mxfp.act_quant_type
|
||||||
stage1_kwargs.update({"tp_world_size": 1, "tp_rank_id": 0, "y_dtype": y_dtype})
|
stage1_kwargs.update({"tp_world_size": 1, "tp_rank_id": 0, "y_dtype": y_dtype})
|
||||||
if self.need_expert_scale or self.a5_need_extra_args:
|
if self.need_expert_scale or self.a5_need_extra_args:
|
||||||
stage1_kwargs.update(
|
stage1_kwargs.update(
|
||||||
@@ -195,22 +182,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
|||||||
|
|
||||||
def token_dispatch(
|
def token_dispatch(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
token_dispatch_input: MoETokenDispatchInput,
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
expert_map: torch.Tensor | None = None,
|
|
||||||
global_redundant_expert_num: int = 0,
|
|
||||||
mc2_mask: torch.Tensor | None = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
with_quant: bool = False,
|
|
||||||
dynamic_eplb: bool = False,
|
|
||||||
pertoken_scale: torch.Tensor | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
self.with_quant = with_quant
|
kwargs_mc2 = self.get_dispatch_mc2_kwargs(token_dispatch_input)
|
||||||
kwargs_mc2 = self.get_dispatch_mc2_kwargs(
|
|
||||||
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num, **kwargs
|
|
||||||
)
|
|
||||||
output = (
|
output = (
|
||||||
torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
|
torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
|
||||||
if self.enable_dispatch_v2
|
if self.enable_dispatch_v2
|
||||||
@@ -227,33 +201,32 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
|||||||
expand_scales,
|
expand_scales,
|
||||||
) = output[0:7]
|
) = output[0:7]
|
||||||
|
|
||||||
context_metadata = {
|
|
||||||
"topk_ids": topk_ids,
|
|
||||||
"topk_weights": topk_weights,
|
|
||||||
"expert_map": expert_map,
|
|
||||||
"ep_recv_counts": ep_recv_counts,
|
|
||||||
"tp_recv_counts": tp_recv_counts,
|
|
||||||
"assist_info_for_combine": assist_info_for_combine,
|
|
||||||
"expand_scales": expand_scales,
|
|
||||||
}
|
|
||||||
|
|
||||||
group_list_type = 0
|
group_list_type = 0
|
||||||
return TokenDispatchResult(
|
return MoETokenDispatchOutput(
|
||||||
hidden_states=expand_x,
|
hidden_states=expand_x,
|
||||||
dynamic_scale=dynamic_scale,
|
dynamic_scale=dynamic_scale,
|
||||||
group_list=expert_token_nums,
|
group_list=expert_token_nums,
|
||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
context_metadata=context_metadata,
|
combine_metadata=MoEMC2CombineMetadata(
|
||||||
|
topk_ids=token_dispatch_input.topk_ids,
|
||||||
|
topk_weights=token_dispatch_input.topk_weights,
|
||||||
|
expert_map=token_dispatch_input.routing.expert_map,
|
||||||
|
ep_recv_counts=ep_recv_counts,
|
||||||
|
tp_recv_counts=tp_recv_counts,
|
||||||
|
assist_info_for_combine=assist_info_for_combine,
|
||||||
|
expand_scales=expand_scales,
|
||||||
|
dispatch_with_quant=token_dispatch_input.quant.dispatch_with_quant,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, context_metadata: dict):
|
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, combine_metadata: MoEMC2CombineMetadata):
|
||||||
expert_map = context_metadata["expert_map"]
|
expert_map = combine_metadata.expert_map
|
||||||
topk_ids = context_metadata["topk_ids"]
|
topk_ids = combine_metadata.topk_ids
|
||||||
topk_weights = context_metadata["topk_weights"]
|
topk_weights = combine_metadata.topk_weights
|
||||||
ep_recv_counts = context_metadata["ep_recv_counts"]
|
ep_recv_counts = combine_metadata.ep_recv_counts
|
||||||
tp_recv_counts = context_metadata["tp_recv_counts"]
|
tp_recv_counts = combine_metadata.tp_recv_counts
|
||||||
assist_info_for_combine = context_metadata["assist_info_for_combine"]
|
assist_info_for_combine = combine_metadata.assist_info_for_combine
|
||||||
expand_scales = context_metadata["expand_scales"]
|
expand_scales = combine_metadata.expand_scales
|
||||||
|
|
||||||
assert expert_map is not None
|
assert expert_map is not None
|
||||||
|
|
||||||
@@ -267,7 +240,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
|||||||
"global_bs": self.global_bs,
|
"global_bs": self.global_bs,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.with_quant:
|
if combine_metadata.dispatch_with_quant:
|
||||||
tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device)
|
tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device)
|
||||||
|
|
||||||
stage3_kwargs = {
|
stage3_kwargs = {
|
||||||
@@ -296,52 +269,44 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
|||||||
kwargs_mc2.update(stage3_kwargs)
|
kwargs_mc2.update(stage3_kwargs)
|
||||||
return kwargs_mc2
|
return kwargs_mc2
|
||||||
|
|
||||||
def token_combine(self, hidden_states, context_metadata, bias=None):
|
def token_combine(self, hidden_states, combine_metadata, bias=None):
|
||||||
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
||||||
|
|
||||||
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, context_metadata)
|
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, combine_metadata)
|
||||||
combined_output = (
|
combined_output = (
|
||||||
torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2)
|
torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2)
|
||||||
if self.enable_dispatch_v2
|
if self.enable_dispatch_v2
|
||||||
else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
||||||
)
|
)
|
||||||
|
|
||||||
return TokenCombineResult(
|
return combined_output
|
||||||
routed_out=combined_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
class TokenDispatcherWithAllGather(MoETokenDispatcher[MoEAllGatherCombineMetadata]):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.apply_router_weight_on_input = False
|
|
||||||
self.max_num_tokens = kwargs.get("max_num_tokens")
|
self.max_num_tokens = kwargs.get("max_num_tokens")
|
||||||
num_experts_local = kwargs.get("num_local_experts", 0)
|
num_experts_local = kwargs.get("num_local_experts", 0)
|
||||||
self.num_experts_local = (
|
self.num_experts_local = (
|
||||||
num_experts_local.item() if torch.is_tensor(num_experts_local) else int(num_experts_local)
|
num_experts_local.item() if torch.is_tensor(num_experts_local) else int(num_experts_local)
|
||||||
)
|
)
|
||||||
self.original_shape = None
|
|
||||||
self.with_quant = False
|
|
||||||
|
|
||||||
def token_dispatch(
|
def token_dispatch(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
token_dispatch_input: MoETokenDispatchInput,
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
expert_map: torch.Tensor | None = None,
|
|
||||||
global_redundant_expert_num: int = 0,
|
|
||||||
mc2_mask: torch.Tensor | None = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
with_quant: bool = False,
|
|
||||||
dynamic_eplb: bool = False,
|
|
||||||
pertoken_scale: torch.Tensor | None = None,
|
|
||||||
):
|
):
|
||||||
self.with_quant = with_quant
|
with_quant = token_dispatch_input.quant.is_int_quant
|
||||||
self.original_shape = hidden_states.shape
|
hidden_states = token_dispatch_input.hidden_states
|
||||||
|
topk_weights = token_dispatch_input.topk_weights
|
||||||
|
topk_ids = token_dispatch_input.topk_ids
|
||||||
|
expert_map = token_dispatch_input.routing.expert_map
|
||||||
|
pertoken_scale = token_dispatch_input.routing.pertoken_scale
|
||||||
|
global_redundant_expert_num = token_dispatch_input.routing.global_redundant_expert_num
|
||||||
|
restore_shape = hidden_states.shape
|
||||||
|
|
||||||
num_tokens = hidden_states.shape[:-1].numel()
|
num_tokens = hidden_states.shape[:-1].numel()
|
||||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
apply_router_weight_on_input = token_dispatch_input.routing.apply_router_weight_on_input
|
||||||
if self.apply_router_weight_on_input:
|
if apply_router_weight_on_input:
|
||||||
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
|
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
|
||||||
_, topk = topk_weights.shape
|
_, topk = topk_weights.shape
|
||||||
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
|
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||||
@@ -365,35 +330,37 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
|||||||
expert_tokens_num_type=1,
|
expert_tokens_num_type=1,
|
||||||
expert_tokens_num_flag=True,
|
expert_tokens_num_flag=True,
|
||||||
active_expert_range=[first_expert_idx, last_expert_idx],
|
active_expert_range=[first_expert_idx, last_expert_idx],
|
||||||
quant_mode=1 if self.with_quant and pertoken_scale is None else -1,
|
quant_mode=1 if with_quant and pertoken_scale is None else -1,
|
||||||
)
|
)
|
||||||
expert_tokens = expert_tokens.to(torch.int64)
|
expert_tokens = expert_tokens.to(torch.int64)
|
||||||
group_list_type = 1 # `count` mode
|
group_list_type = 1 # `count` mode
|
||||||
context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
|
|
||||||
|
|
||||||
return TokenDispatchResult(
|
return MoETokenDispatchOutput(
|
||||||
hidden_states=sorted_hidden_states,
|
hidden_states=sorted_hidden_states,
|
||||||
dynamic_scale=pertoken_scale if self.with_quant else None,
|
dynamic_scale=pertoken_scale if with_quant else None,
|
||||||
group_list=expert_tokens,
|
group_list=expert_tokens,
|
||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
context_metadata=context_metadata,
|
combine_metadata=MoEAllGatherCombineMetadata(
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
expanded_row_idx=expanded_row_idx,
|
||||||
|
restore_shape=restore_shape,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def token_combine(self, hidden_states, context_metadata, bias=None):
|
def token_combine(self, hidden_states, combine_metadata, bias=None):
|
||||||
assert self.original_shape is not None
|
|
||||||
final_hidden_states = torch_npu.npu_moe_token_unpermute(
|
final_hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||||
permuted_tokens=hidden_states,
|
permuted_tokens=hidden_states,
|
||||||
sorted_indices=torch.abs(context_metadata["expanded_row_idx"]),
|
sorted_indices=torch.abs(combine_metadata.expanded_row_idx),
|
||||||
probs=context_metadata["topk_weights"],
|
probs=combine_metadata.topk_weights,
|
||||||
)
|
)
|
||||||
if len(self.original_shape) == 3:
|
if len(combine_metadata.restore_shape) == 3:
|
||||||
final_hidden_states = final_hidden_states.view(self.original_shape)
|
final_hidden_states = final_hidden_states.view(combine_metadata.restore_shape)
|
||||||
|
|
||||||
# these values are no longer used, so they need to be set to None for memory release.
|
# these values are no longer used, so they need to be set to None for memory release.
|
||||||
return TokenCombineResult(routed_out=final_hidden_states)
|
return final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
class TokenDispatcherWithAll2AllV(MoETokenDispatcher[MoEAllToAllCombineMetadata]):
|
||||||
"""
|
"""
|
||||||
The implementation of the AlltoAll-based token dispatcher, which handles token
|
The implementation of the AlltoAll-based token dispatcher, which handles token
|
||||||
dispatching on the sequence level instead of token level. The core of this implementation
|
dispatching on the sequence level instead of token level. The core of this implementation
|
||||||
@@ -402,12 +369,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
|||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.with_quant = False
|
|
||||||
self.num_local_experts = kwargs.get("num_local_experts", 0)
|
self.num_local_experts = kwargs.get("num_local_experts", 0)
|
||||||
|
|
||||||
self.hidden_shape = None
|
|
||||||
self.hidden_shape_before_permute = None
|
|
||||||
|
|
||||||
assert self.num_local_experts > 0, "Expected at least one expert"
|
assert self.num_local_experts > 0, "Expected at least one expert"
|
||||||
if self.num_local_experts > 1:
|
if self.num_local_experts > 1:
|
||||||
self.expert_ids_per_ep_rank = torch.tensor(
|
self.expert_ids_per_ep_rank = torch.tensor(
|
||||||
@@ -432,19 +395,12 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
|||||||
|
|
||||||
def token_dispatch(
|
def token_dispatch(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
token_dispatch_input: MoETokenDispatchInput,
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
expert_map: torch.Tensor | None = None,
|
|
||||||
global_redundant_expert_num: int = 0,
|
|
||||||
mc2_mask: torch.Tensor | None = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
with_quant: bool = False,
|
|
||||||
dynamic_eplb: bool = False,
|
|
||||||
pertoken_scale: torch.Tensor | None = None,
|
|
||||||
):
|
):
|
||||||
self.with_quant = with_quant
|
with_quant = token_dispatch_input.quant.is_int_quant
|
||||||
self.hidden_shape = hidden_states.shape
|
hidden_states = token_dispatch_input.hidden_states
|
||||||
|
topk_weights = token_dispatch_input.topk_weights
|
||||||
|
topk_ids = token_dispatch_input.topk_ids
|
||||||
|
|
||||||
(
|
(
|
||||||
permutated_local_input_tokens,
|
permutated_local_input_tokens,
|
||||||
@@ -452,12 +408,13 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
|||||||
tokens_per_expert,
|
tokens_per_expert,
|
||||||
input_splits,
|
input_splits,
|
||||||
output_splits,
|
output_splits,
|
||||||
num_global_tokens_per_local_expert,
|
|
||||||
global_input_tokens_local_experts_indices,
|
global_input_tokens_local_experts_indices,
|
||||||
|
hidden_shape,
|
||||||
|
hidden_shape_before_permute,
|
||||||
) = self._dispatch_preprocess(hidden_states, topk_ids)
|
) = self._dispatch_preprocess(hidden_states, topk_ids)
|
||||||
|
|
||||||
dynamic_scale_after_all2all = None
|
dynamic_scale_after_all2all = None
|
||||||
if self.with_quant:
|
if with_quant:
|
||||||
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(permutated_local_input_tokens)
|
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(permutated_local_input_tokens)
|
||||||
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
|
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
|
||||||
dynamic_scale, output_splits, input_splits, self.ep_group
|
dynamic_scale, output_splits, input_splits, self.ep_group
|
||||||
@@ -474,64 +431,66 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
|||||||
# Postprocess
|
# Postprocess
|
||||||
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = (
|
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = (
|
||||||
self._dispatch_postprocess(
|
self._dispatch_postprocess(
|
||||||
global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices
|
global_input_tokens,
|
||||||
|
dynamic_scale_after_all2all,
|
||||||
|
global_input_tokens_local_experts_indices,
|
||||||
|
with_quant,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
context_metadata = {
|
return MoETokenDispatchOutput(
|
||||||
"input_splits": input_splits,
|
|
||||||
"output_splits": output_splits,
|
|
||||||
"topk_weights": topk_weights,
|
|
||||||
"reversed_local_input_permutation_mapping": reversed_local_input_permutation_mapping,
|
|
||||||
"reversed_global_input_permutation_mapping": reversed_global_input_permutation_mapping,
|
|
||||||
}
|
|
||||||
|
|
||||||
return TokenDispatchResult(
|
|
||||||
hidden_states=global_input_tokens,
|
hidden_states=global_input_tokens,
|
||||||
dynamic_scale=dynamic_scale_final,
|
dynamic_scale=dynamic_scale_final,
|
||||||
group_list=tokens_per_expert,
|
group_list=tokens_per_expert,
|
||||||
group_list_type=1,
|
group_list_type=1,
|
||||||
context_metadata=context_metadata,
|
combine_metadata=MoEAllToAllCombineMetadata(
|
||||||
|
input_splits=input_splits,
|
||||||
|
output_splits=output_splits,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
reversed_local_input_permutation_mapping=reversed_local_input_permutation_mapping,
|
||||||
|
reversed_global_input_permutation_mapping=reversed_global_input_permutation_mapping,
|
||||||
|
hidden_shape=hidden_shape,
|
||||||
|
hidden_shape_before_permute=hidden_shape_before_permute,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def token_combine(self, hidden_states, context_metadata, bias=None):
|
def token_combine(self, hidden_states, combine_metadata, bias=None):
|
||||||
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
||||||
|
|
||||||
# 1. Preprocess using metadata
|
# 1. Preprocess using metadata
|
||||||
hidden_states = self._combine_preprocess(hidden_states, context_metadata)
|
hidden_states = self._combine_preprocess(hidden_states, combine_metadata)
|
||||||
|
|
||||||
# 2. AllToAll
|
# 2. AllToAll
|
||||||
_, permutated_local_input_tokens, handle = async_all_to_all(
|
_, permutated_local_input_tokens, handle = async_all_to_all(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
context_metadata["input_splits"],
|
combine_metadata.input_splits,
|
||||||
context_metadata["output_splits"],
|
combine_metadata.output_splits,
|
||||||
self.ep_group,
|
self.ep_group,
|
||||||
)
|
)
|
||||||
handle.wait()
|
handle.wait()
|
||||||
hidden_states.untyped_storage().resize_(0)
|
hidden_states.untyped_storage().resize_(0)
|
||||||
|
|
||||||
# 3. Postprocess using metadata
|
# 3. Postprocess using metadata
|
||||||
output = self._combine_postprocess(permutated_local_input_tokens, context_metadata)
|
output = self._combine_postprocess(permutated_local_input_tokens, combine_metadata)
|
||||||
|
|
||||||
return TokenCombineResult(routed_out=output)
|
return output
|
||||||
|
|
||||||
def _dispatch_preprocess(self, hidden_states, topk_ids):
|
def _dispatch_preprocess(self, hidden_states, topk_ids):
|
||||||
assert self.hidden_shape is not None
|
hidden_shape = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||||
(
|
(
|
||||||
tokens_per_expert,
|
tokens_per_expert,
|
||||||
input_splits,
|
input_splits,
|
||||||
output_splits,
|
output_splits,
|
||||||
num_global_tokens_per_local_expert,
|
|
||||||
global_input_tokens_local_experts_indices,
|
global_input_tokens_local_experts_indices,
|
||||||
|
num_out_tokens,
|
||||||
) = self._preprocess(topk_ids)
|
) = self._preprocess(topk_ids)
|
||||||
|
hidden_shape_before_permute = hidden_states.shape
|
||||||
self.hidden_shape_before_permute = hidden_states.shape
|
|
||||||
|
|
||||||
permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||||
tokens=hidden_states,
|
tokens=hidden_states,
|
||||||
indices=topk_ids,
|
indices=topk_ids,
|
||||||
num_out_tokens=self.num_out_tokens,
|
num_out_tokens=num_out_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -540,15 +499,16 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
|||||||
tokens_per_expert,
|
tokens_per_expert,
|
||||||
input_splits,
|
input_splits,
|
||||||
output_splits,
|
output_splits,
|
||||||
num_global_tokens_per_local_expert,
|
|
||||||
global_input_tokens_local_experts_indices,
|
global_input_tokens_local_experts_indices,
|
||||||
|
hidden_shape,
|
||||||
|
hidden_shape_before_permute,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess(self, topk_ids: torch.Tensor):
|
def _preprocess(self, topk_ids: torch.Tensor):
|
||||||
num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts)
|
num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts)
|
||||||
|
|
||||||
ep_size = self.ep_size
|
ep_size = self.ep_size
|
||||||
self.num_out_tokens = topk_ids.numel()
|
num_out_tokens = topk_ids.numel()
|
||||||
|
|
||||||
input_splits = (
|
input_splits = (
|
||||||
num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
|
num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
|
||||||
@@ -585,19 +545,19 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
|||||||
num_tokens_per_local_expert,
|
num_tokens_per_local_expert,
|
||||||
input_splits,
|
input_splits,
|
||||||
output_splits,
|
output_splits,
|
||||||
num_global_tokens_per_local_expert,
|
|
||||||
global_input_tokens_local_experts_indices,
|
global_input_tokens_local_experts_indices,
|
||||||
|
num_out_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _dispatch_postprocess(
|
def _dispatch_postprocess(
|
||||||
self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices
|
self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices, with_quant
|
||||||
):
|
):
|
||||||
# Early return if no local experts or no tokens
|
# Early return if no local experts or no tokens
|
||||||
if self.num_local_experts <= 1:
|
if self.num_local_experts <= 1:
|
||||||
return global_input_tokens, dynamic_scale_after_all2all, None
|
return global_input_tokens, dynamic_scale_after_all2all, None
|
||||||
|
|
||||||
# Handle quantized case
|
# Handle quantized case
|
||||||
if self.with_quant:
|
if with_quant:
|
||||||
assert global_input_tokens_local_experts_indices is not None, (
|
assert global_input_tokens_local_experts_indices is not None, (
|
||||||
"global_input_tokens_local_experts_indices must be provided"
|
"global_input_tokens_local_experts_indices must be provided"
|
||||||
)
|
)
|
||||||
@@ -612,20 +572,26 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
|||||||
)
|
)
|
||||||
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
|
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
|
||||||
|
|
||||||
def _combine_preprocess(self, hidden_states: torch.Tensor, context_metadata: dict) -> torch.Tensor:
|
def _combine_preprocess(
|
||||||
|
self, hidden_states: torch.Tensor, combine_metadata: MoEAllToAllCombineMetadata
|
||||||
|
) -> torch.Tensor:
|
||||||
# Unpermutation 2: expert output to AlltoAll input
|
# Unpermutation 2: expert output to AlltoAll input
|
||||||
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
|
rev_global = combine_metadata.reversed_global_input_permutation_mapping
|
||||||
rev_global = context_metadata["reversed_global_input_permutation_mapping"]
|
if hidden_states.shape[0] > 0 and self.num_local_experts > 1 and rev_global is not None:
|
||||||
hidden_states = torch_npu.npu_moe_token_unpermute(hidden_states, rev_global)
|
hidden_states = torch_npu.npu_moe_token_unpermute(hidden_states, rev_global)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor, context_metadata: dict) -> torch.Tensor:
|
def _combine_postprocess(
|
||||||
|
self,
|
||||||
|
permutated_local_input_tokens: torch.Tensor,
|
||||||
|
combine_metadata: MoEAllToAllCombineMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
# Unpermutation 1: AlltoAll output to output
|
# Unpermutation 1: AlltoAll output to output
|
||||||
output = torch_npu.npu_moe_token_unpermute(
|
output = torch_npu.npu_moe_token_unpermute(
|
||||||
permuted_tokens=permutated_local_input_tokens,
|
permuted_tokens=permutated_local_input_tokens,
|
||||||
sorted_indices=context_metadata["reversed_local_input_permutation_mapping"].to(torch.int32),
|
sorted_indices=combine_metadata.reversed_local_input_permutation_mapping.to(torch.int32),
|
||||||
probs=context_metadata["topk_weights"],
|
probs=combine_metadata.topk_weights,
|
||||||
restore_shape=self.hidden_shape_before_permute,
|
restore_shape=combine_metadata.hidden_shape_before_permute,
|
||||||
)
|
)
|
||||||
output = output.view(self.hidden_shape)
|
output = output.view(combine_metadata.hidden_shape)
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -16,24 +16,30 @@
|
|||||||
#
|
#
|
||||||
"""Ascend quantization module.
|
"""Ascend quantization module.
|
||||||
|
|
||||||
This module provides quantization support for Ascend NPU.
|
This module intentionally avoids eager imports so that importing lightweight
|
||||||
|
submodules (for example ``quant_type``) does not trigger heavy registration
|
||||||
Supported quantization tools:
|
paths and circular imports during startup.
|
||||||
- ModelSlim: Use AscendModelSlimConfig
|
|
||||||
- LLM-Compressor (compressed_tensors): Use AscendCompressedTensorsConfig
|
|
||||||
|
|
||||||
Public API:
|
|
||||||
- Config classes: AscendModelSlimConfig, AscendCompressedTensorsConfig
|
|
||||||
- For scheme implementations, import from vllm_ascend.quantization.methods
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# LLM-Compressor (compressed_tensors) quantization config
|
from typing import TYPE_CHECKING, Any
|
||||||
from .compressed_tensors_config import AscendCompressedTensorsConfig
|
|
||||||
|
|
||||||
# ModelSlim quantization config
|
if TYPE_CHECKING:
|
||||||
from .modelslim_config import AscendModelSlimConfig
|
from .compressed_tensors_config import AscendCompressedTensorsConfig
|
||||||
|
from .modelslim_config import AscendModelSlimConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AscendModelSlimConfig",
|
"AscendModelSlimConfig",
|
||||||
"AscendCompressedTensorsConfig",
|
"AscendCompressedTensorsConfig",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
|
||||||
|
if name == "AscendModelSlimConfig":
|
||||||
|
from .modelslim_config import AscendModelSlimConfig
|
||||||
|
|
||||||
|
return AscendModelSlimConfig
|
||||||
|
if name == "AscendCompressedTensorsConfig":
|
||||||
|
from .compressed_tensors_config import AscendCompressedTensorsConfig
|
||||||
|
|
||||||
|
return AscendCompressedTensorsConfig
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|||||||
@@ -255,28 +255,34 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
|||||||
enable_force_load_balance: bool = False,
|
enable_force_load_balance: bool = False,
|
||||||
log2phy: torch.Tensor | None = None,
|
log2phy: torch.Tensor | None = None,
|
||||||
global_redundant_expert_num=0,
|
global_redundant_expert_num=0,
|
||||||
**kwargs,
|
pertoken_scale: torch.Tensor | None = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
mc2_mask: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.quant_method.apply(
|
return self.quant_method.apply(
|
||||||
layer,
|
layer=layer,
|
||||||
x,
|
x=x,
|
||||||
router_logits,
|
router_logits=router_logits,
|
||||||
top_k,
|
top_k=top_k,
|
||||||
renormalize,
|
renormalize=renormalize,
|
||||||
use_grouped_topk,
|
use_grouped_topk=use_grouped_topk,
|
||||||
global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map,
|
expert_map=expert_map,
|
||||||
topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
scoring_func,
|
scoring_func=scoring_func,
|
||||||
routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
is_prefill,
|
is_prefill=is_prefill,
|
||||||
enable_force_load_balance,
|
enable_force_load_balance=enable_force_load_balance,
|
||||||
log2phy,
|
log2phy=log2phy,
|
||||||
global_redundant_expert_num,
|
global_redundant_expert_num=global_redundant_expert_num,
|
||||||
**kwargs,
|
pertoken_scale=pertoken_scale,
|
||||||
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
mc2_mask=mc2_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
|||||||
@@ -18,19 +18,11 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from enum import Enum
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm_ascend.quantization.quant_type import QuantType
|
||||||
class QuantType(Enum):
|
|
||||||
"""Quantization type enum for MoE schemes."""
|
|
||||||
|
|
||||||
NONE = 0
|
|
||||||
W8A8 = 1
|
|
||||||
W4A8 = 2
|
|
||||||
MXFP8 = 3
|
|
||||||
|
|
||||||
|
|
||||||
class AscendLinearScheme(ABC):
|
class AscendLinearScheme(ABC):
|
||||||
@@ -245,7 +237,10 @@ class AscendMoEScheme(ABC):
|
|||||||
enable_force_load_balance: bool = False,
|
enable_force_load_balance: bool = False,
|
||||||
log2phy: torch.Tensor | None = None,
|
log2phy: torch.Tensor | None = None,
|
||||||
global_redundant_expert_num: int = 0,
|
global_redundant_expert_num: int = 0,
|
||||||
**kwargs,
|
pertoken_scale: Any | None = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
mc2_mask: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward computation for MoE layer.
|
"""Forward computation for MoE layer.
|
||||||
|
|
||||||
@@ -268,7 +263,10 @@ class AscendMoEScheme(ABC):
|
|||||||
enable_force_load_balance: Whether to force load balancing.
|
enable_force_load_balance: Whether to force load balancing.
|
||||||
log2phy: Logical to physical expert mapping.
|
log2phy: Logical to physical expert mapping.
|
||||||
global_redundant_expert_num: Number of redundant experts.
|
global_redundant_expert_num: Number of redundant experts.
|
||||||
**kwargs: Additional keyword arguments.
|
pertoken_scale: Optional per-token activation scale from prepare stage.
|
||||||
|
activation: Expert MLP activation type.
|
||||||
|
apply_router_weight_on_input: Whether to pre-scale hidden states by router weights.
|
||||||
|
mc2_mask: Optional mask used by MC2 dispatch.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Output tensor after MoE computation.
|
Output tensor after MoE computation.
|
||||||
|
|||||||
@@ -25,8 +25,9 @@ from vllm.config import get_current_vllm_config
|
|||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||||
|
|
||||||
from .base import AscendMoEScheme
|
from .base import AscendMoEScheme, QuantType
|
||||||
from .registry import register_scheme
|
from .registry import register_scheme
|
||||||
|
|
||||||
|
|
||||||
@@ -103,6 +104,8 @@ def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
|
|||||||
class AscendW4A16FusedMoEMethod(AscendMoEScheme):
|
class AscendW4A16FusedMoEMethod(AscendMoEScheme):
|
||||||
"""FusedMoE method for Ascend W4A16."""
|
"""FusedMoE method for Ascend W4A16."""
|
||||||
|
|
||||||
|
quant_type: QuantType = QuantType.W4A16
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.transpose_weight = True
|
self.transpose_weight = True
|
||||||
self.num_bits = 4 # dtype = torch.int4
|
self.num_bits = 4 # dtype = torch.int4
|
||||||
@@ -192,7 +195,10 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
|
|||||||
enable_force_load_balance: bool = True,
|
enable_force_load_balance: bool = True,
|
||||||
log2phy: torch.Tensor | None = None,
|
log2phy: torch.Tensor | None = None,
|
||||||
global_redundant_expert_num: int = 0,
|
global_redundant_expert_num: int = 0,
|
||||||
**kwargs,
|
pertoken_scale: Any | None = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
mc2_mask: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
|
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
|
||||||
"Number of global experts mismatch (excluding redundancy)"
|
"Number of global experts mismatch (excluding redundancy)"
|
||||||
@@ -217,20 +223,26 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
|
|||||||
|
|
||||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||||
return moe_comm_method.fused_experts(
|
return moe_comm_method.fused_experts(
|
||||||
hidden_states=x,
|
fused_experts_input=build_fused_experts_input(
|
||||||
w1=layer.w13_weight_packed,
|
hidden_states=x,
|
||||||
w2=layer.w2_weight_packed,
|
topk_weights=topk_weights,
|
||||||
w1_scale=layer.w13_weight_scale,
|
topk_ids=topk_ids,
|
||||||
w2_scale=layer.w2_weight_scale,
|
w1=layer.w13_weight_packed,
|
||||||
w1_offset=layer.w13_weight_offset,
|
w2=layer.w2_weight_packed,
|
||||||
w2_offset=layer.w2_weight_offset,
|
quant_type=self.quant_type,
|
||||||
topk_weights=topk_weights,
|
dynamic_eplb=self.dynamic_eplb,
|
||||||
topk_ids=topk_ids,
|
expert_map=expert_map,
|
||||||
use_int4_w4a16=True,
|
global_redundant_expert_num=global_redundant_expert_num,
|
||||||
expert_map=expert_map,
|
mc2_mask=mc2_mask,
|
||||||
log2phy=log2phy,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
dynamic_eplb=self.dynamic_eplb,
|
log2phy=log2phy,
|
||||||
mc2_mask=kwargs.get("mc2_mask"),
|
pertoken_scale=pertoken_scale,
|
||||||
|
activation=activation,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
w1_offset=layer.w13_weight_offset,
|
||||||
|
w2_offset=layer.w2_weight_offset,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from vllm_ascend.ascend_config import get_ascend_config
|
|||||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||||
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
|
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
|
||||||
|
|
||||||
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
||||||
@@ -343,7 +344,10 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
|
|||||||
enable_force_load_balance: bool = False,
|
enable_force_load_balance: bool = False,
|
||||||
log2phy: torch.Tensor | None = None,
|
log2phy: torch.Tensor | None = None,
|
||||||
global_redundant_expert_num: int = 0,
|
global_redundant_expert_num: int = 0,
|
||||||
**kwargs,
|
pertoken_scale: torch.Tensor | None = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
mc2_mask: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
|
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
|
||||||
"Number of global experts mismatch (excluding redundancy)"
|
"Number of global experts mismatch (excluding redundancy)"
|
||||||
@@ -377,20 +381,26 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
|
|||||||
|
|
||||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||||
return moe_comm_method.fused_experts(
|
return moe_comm_method.fused_experts(
|
||||||
hidden_states=x,
|
fused_experts_input=build_fused_experts_input(
|
||||||
w1=[layer.w13_weight],
|
hidden_states=x,
|
||||||
w2=[layer.w2_weight],
|
topk_weights=topk_weights,
|
||||||
w1_scale=[layer.w13_weight_scale],
|
topk_ids=topk_ids,
|
||||||
w2_scale=[layer.w2_weight_scale],
|
w1=[layer.w13_weight],
|
||||||
w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None,
|
w2=[layer.w2_weight],
|
||||||
w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None,
|
quant_type=self.quant_type,
|
||||||
topk_weights=topk_weights,
|
dynamic_eplb=self.dynamic_eplb,
|
||||||
topk_ids=topk_ids,
|
expert_map=expert_map,
|
||||||
use_int4_w4a8=True,
|
global_redundant_expert_num=global_redundant_expert_num,
|
||||||
expert_map=expert_map,
|
mc2_mask=mc2_mask,
|
||||||
log2phy=log2phy,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
dynamic_eplb=self.dynamic_eplb,
|
log2phy=log2phy,
|
||||||
mc2_mask=kwargs.get("mc2_mask"),
|
pertoken_scale=pertoken_scale,
|
||||||
|
activation=activation,
|
||||||
|
w1_scale=[layer.w13_weight_scale],
|
||||||
|
w2_scale=[layer.w2_weight_scale],
|
||||||
|
w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None,
|
||||||
|
w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
|||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
from vllm_ascend.flash_common3_context import get_flash_common3_context
|
from vllm_ascend.flash_common3_context import get_flash_common3_context
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
|
||||||
|
|
||||||
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
||||||
@@ -182,7 +183,9 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
|||||||
log2phy: torch.Tensor | None = None,
|
log2phy: torch.Tensor | None = None,
|
||||||
global_redundant_expert_num: int = 0,
|
global_redundant_expert_num: int = 0,
|
||||||
pertoken_scale: Any | None = None,
|
pertoken_scale: Any | None = None,
|
||||||
**kwargs,
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
mc2_mask: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
||||||
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
||||||
@@ -249,19 +252,24 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
|||||||
w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale]
|
w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale]
|
||||||
|
|
||||||
final_hidden_states = moe_comm_method.fused_experts(
|
final_hidden_states = moe_comm_method.fused_experts(
|
||||||
hidden_states=x,
|
fused_experts_input=build_fused_experts_input(
|
||||||
pertoken_scale=pertoken_scale,
|
hidden_states=x,
|
||||||
w1=w1,
|
topk_weights=topk_weights,
|
||||||
w1_scale=w1_scale,
|
topk_ids=topk_ids,
|
||||||
w2=w2,
|
w1=w1,
|
||||||
w2_scale=w2_scale,
|
w2=w2,
|
||||||
topk_weights=topk_weights,
|
quant_type=self.quant_type,
|
||||||
topk_ids=topk_ids,
|
dynamic_eplb=self.dynamic_eplb,
|
||||||
use_int8_w8a8=True,
|
expert_map=expert_map,
|
||||||
expert_map=expert_map,
|
global_redundant_expert_num=global_redundant_expert_num,
|
||||||
log2phy=log2phy,
|
mc2_mask=mc2_mask,
|
||||||
dynamic_eplb=self.dynamic_eplb,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
mc2_mask=kwargs.get("mc2_mask"),
|
log2phy=log2phy,
|
||||||
|
pertoken_scale=pertoken_scale,
|
||||||
|
activation=activation,
|
||||||
|
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
|
||||||
|
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||||
final_hidden_states += zero_expert_result
|
final_hidden_states += zero_expert_result
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from vllm_ascend.device.mxfp_compat import (
|
|||||||
ensure_mxfp8_moe_available,
|
ensure_mxfp8_moe_available,
|
||||||
)
|
)
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||||
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||||
|
|
||||||
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
||||||
from .registry import register_scheme
|
from .registry import register_scheme
|
||||||
@@ -170,7 +171,10 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
|
|||||||
enable_force_load_balance: bool = True,
|
enable_force_load_balance: bool = True,
|
||||||
log2phy: torch.Tensor = None,
|
log2phy: torch.Tensor = None,
|
||||||
global_redundant_expert_num: int = 0,
|
global_redundant_expert_num: int = 0,
|
||||||
**kwargs,
|
pertoken_scale: Any | None = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
mc2_mask: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
expected = global_num_experts - global_redundant_expert_num
|
expected = global_num_experts - global_redundant_expert_num
|
||||||
assert router_logits.shape[1] == expected, "Number of global experts mismatch (excluding redundancy)"
|
assert router_logits.shape[1] == expected, "Number of global experts mismatch (excluding redundancy)"
|
||||||
@@ -198,23 +202,29 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
|
|||||||
|
|
||||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||||
return moe_comm_method.fused_experts(
|
return moe_comm_method.fused_experts(
|
||||||
hidden_states=x,
|
fused_experts_input=build_fused_experts_input(
|
||||||
w1=layer.w13_weight,
|
hidden_states=x,
|
||||||
w1_scale=layer.w13_weight_scale,
|
topk_weights=topk_weights,
|
||||||
w2=layer.w2_weight,
|
topk_ids=topk_ids,
|
||||||
w2_scale=layer.w2_weight_scale,
|
w1=layer.w13_weight,
|
||||||
topk_weights=topk_weights,
|
w2=layer.w2_weight,
|
||||||
topk_ids=topk_ids,
|
quant_type=self.quant_type,
|
||||||
use_int8_w8a8=False,
|
dynamic_eplb=self.dynamic_eplb,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
log2phy=log2phy,
|
global_redundant_expert_num=global_redundant_expert_num,
|
||||||
dynamic_eplb=self.dynamic_eplb,
|
mc2_mask=mc2_mask,
|
||||||
mc2_mask=kwargs.get("mc2_mask"),
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_mxfp_quant=True,
|
log2phy=log2phy,
|
||||||
act_quant_type=torch.float8_e4m3fn,
|
pertoken_scale=pertoken_scale,
|
||||||
weight_quant_type=torch.float8_e4m3fn,
|
activation=activation,
|
||||||
scale_type=FLOAT8_E8M0FNU_DTYPE,
|
mxfp_act_quant_type=torch.float8_e4m3fn,
|
||||||
per_token_scale_type=FLOAT8_E8M0FNU_DTYPE,
|
mxfp_weight_quant_type=torch.float8_e4m3fn,
|
||||||
|
mxfp_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
|
||||||
|
mxfp_per_token_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
|
||||||
|
mxfp_use_bf16=(x.dtype == torch.bfloat16),
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer):
|
def process_weights_after_loading(self, layer):
|
||||||
|
|||||||
33
vllm_ascend/quantization/quant_type.py
Normal file
33
vllm_ascend/quantization/quant_type.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""Shared quantization enum definitions.
|
||||||
|
|
||||||
|
Keep this module lightweight and side-effect free so core runtime modules can
|
||||||
|
import QuantType without triggering heavy quantization package initialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class QuantType(Enum):
|
||||||
|
"""Quantization type enum for MoE schemes."""
|
||||||
|
|
||||||
|
NONE = 0
|
||||||
|
W8A8 = 1
|
||||||
|
W4A8 = 2
|
||||||
|
MXFP8 = 3
|
||||||
|
W4A16 = 4
|
||||||
Reference in New Issue
Block a user