diff --git a/tests/e2e/multicard/test_qwen3_moe.py b/tests/e2e/multicard/test_qwen3_moe.py index 13e1fa3..6e3da1f 100644 --- a/tests/e2e/multicard/test_qwen3_moe.py +++ b/tests/e2e/multicard/test_qwen3_moe.py @@ -66,7 +66,6 @@ def test_models_distributed_Qwen3_MOE_W8A8(): max_model_len=8192, tensor_parallel_size=2, quantization="ascend", - enforce_eager=True, ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py index f8dde63..c6da287 100644 --- a/tests/e2e/singlecard/ops/test_fused_moe.py +++ b/tests/e2e/singlecard/ops/test_fused_moe.py @@ -29,6 +29,7 @@ import torch_npu from vllm.model_executor.layers.activation import SiluAndMul from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather NUM_EXPERTS = [8, 64] @@ -165,6 +166,87 @@ def test_token_dispatcher_with_all_gather( torch.npu.reset_peak_memory_stats() +@pytest.mark.parametrize("m", [1, 33, 64]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("ep_size", EP_SIZE) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", DEVICE) +def test_token_dispatcher_with_all_gather_quant( + m: int, + n: int, + k: int, + e: int, + topk: int, + ep_size: int, + dtype: torch.dtype, + device: str, +): + context_mock = MagicMock() + context_mock.fused_moe_state = 0 + with patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context", + return_value=context_mock): + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randint(-128, 128, (e, k, 2 * n), device=device, dtype=torch.int8) + w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype) + w2 = torch.randint(-128, 128, (e, n, k), 
device=device, dtype=torch.int8) + w2_scale = torch.empty((e, k), device=device, dtype=dtype) + + score = torch.randn((m, e), device=device, dtype=dtype) + expert_map = None + local_e = e + + score = torch.softmax(score, dim=-1, dtype=dtype) + topk_weights, topk_ids = torch.topk(score, topk) + topk_ids = topk_ids.to(torch.int32) + row_idx = (torch.arange( + 0, + m * topk, + device=device, + dtype=torch.int32, + ).view(topk, -1).permute(1, 0).contiguous()) + + dispatcher_kwargs = { + "num_experts": e, + "top_k": topk, + "num_local_experts": local_e, + } + dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs) + + apply_router_weight_on_input = False + dispatch_output = dispatcher.token_dispatch( + hidden_states=a, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + with_quant=True) + + sorted_hidden_states = dispatch_output["hidden_states"] + group_list = dispatch_output["group_list"] + group_list_type = dispatch_output.get("group_list_type", 1) + dynamic_scale = dispatch_output["dynamic_scale"] + + expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=group_list, + group_list_type=group_list_type, + dynamic_scale=dynamic_scale, + with_quant=True) + combined_output = dispatcher.token_combine(hidden_states=expert_output, + bias=None) + assert combined_output.shape == (m, k) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + + @pytest.mark.parametrize("m", [1, 33, 64]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("e", NUM_EXPERTS) diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 53b2fa9..1416d41 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -221,7 +221,7 @@ class TestTokenDispatcherWithAllGather(TestBase): 
self.assertEqual(results["group_list_type"], 1) - def test_token_dispatch_with_quant(self): + def test_token_dispatch_without_quant(self): kwargs = { "apply_router_weight_on_input": False, "top_k": 2, @@ -241,6 +241,32 @@ class TestTokenDispatcherWithAllGather(TestBase): self.assertEqual(results["group_list_type"], 1) + def test_token_dispatch_with_quant(self): + kwargs = { + "apply_router_weight_on_input": False, + "top_k": 2, + "max_num_tokens": 100, + "ep_size": 2, + "num_experts": 128, + } + self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs) + + hidden_states = torch.randn(3, 128) + topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) + topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) + + results = self.dispatcher_quant.token_dispatch(hidden_states, + topk_weights, + topk_ids, + self.row_idx, + None, + with_quant=True) + + self.assertIsNotNone(results["hidden_states"]) + self.assertIsNotNone(results["group_list"]) + self.assertIsNotNone(results["dynamic_scale"]) + self.assertEqual(results["group_list_type"], 1) + def test_token_combine_with_expert_map(self): self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1]) diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py index a5c5566..72a1c34 100644 --- a/vllm_ascend/ops/moe/token_dispatcher.py +++ b/vllm_ascend/ops/moe/token_dispatcher.py @@ -367,7 +367,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): last_expert_idx = self.num_experts_local global_num_experts = self.num_experts_local - sorted_hidden_states, self.expanded_row_idx, expert_tokens, _ = ( + sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = ( torch_npu.npu_moe_init_routing_v2( hidden_states, topk_ids, @@ -376,7 +376,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): expert_tokens_num_type=1, expert_tokens_num_flag=True, active_expert_range=[first_expert_idx, 
last_expert_idx], - quant_mode=-1, + quant_mode=1 if self.with_quant else -1, )) expert_tokens = expert_tokens.to(torch.int64) group_list_type = 1 # `count` mode @@ -384,6 +384,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): "group_list_type": group_list_type, "hidden_states": sorted_hidden_states, "group_list": expert_tokens, + "dynamic_scale": pertoken_scale if self.with_quant else None, } def token_combine(self,