[1/N][refactor] torchair fused_moe refactor (#2438)
### What this PR does / why we need it?
Move the torchair-related fused_moe section into torchair_fused_moe to make the code clearer. As a next step, we will remove all torchair-related code outside of torchair_fused_moe.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main: 08d5f7113a
- vLLM version: v0.10.1.1
- vLLM main: 170e8ea9ea

Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
@@ -112,7 +112,7 @@ def mock_distributed():
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
|
||||
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
||||
_PP=pp_group), \
|
||||
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
|
||||
@@ -227,8 +227,9 @@ def test_torchair_deepseek_v2_moe(mock_distributed, base_config,
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
attn_metadata = Mock(num_prefills=1)
|
||||
with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
|
||||
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
|
||||
with patch(
|
||||
"vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__call__",
|
||||
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
|
||||
output = moe(x, attn_metadata)
|
||||
assert output.shape == (2, 4, 128)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user