[1/N][refactor] torchair fused_moe refactor (#2438)

### What this PR does / why we need it?
Move torchair related fused_moe section into torchair_fused_moe to make
the code clear. Next step we'll remove all torchair related code outside
of torchair_fused_moe .

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
vLLM version: v0.10.0
vLLM main:
08d5f7113a

- vLLM version: v0.10.1.1
- vLLM main:
170e8ea9ea

Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
Wang Yixuan
2025-08-25 15:46:10 +08:00
committed by GitHub
parent 334c44613a
commit 0f81e032f0
5 changed files with 1974 additions and 6 deletions

View File

@@ -112,7 +112,7 @@ def mock_distributed():
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group), \
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
@@ -227,8 +227,9 @@ def test_torchair_deepseek_v2_moe(mock_distributed, base_config,
x = torch.randn(2, 4, 128)
attn_metadata = Mock(num_prefills=1)
with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
with patch(
"vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__call__",
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
output = moe(x, attn_metadata)
assert output.shape == (2, 4, 128)