[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)
### What this PR does / why we need it?
A refactoring of forward_context and model_runner_v1, add some context
which is necessary in model inference into forward_context, and refactor
dummy_run logic, make it more reasonable.
Some details for this PR:
Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
refactor `dummy_run` logic;
Add soc_version for A2 and A3;
### Does this PR introduce _any_ user-facing change?
No change at user-facing.
### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main:
57c22e57f9
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -114,7 +114,16 @@ def mock_distributed():
|
||||
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
||||
_PP=pp_group):
|
||||
_PP=pp_group), \
|
||||
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_forward_context():
|
||||
forward_context = Mock(in_profile_run=False, with_prefill=False)
|
||||
with patch("vllm_ascend.models.deepseek_v2.get_forward_context",
|
||||
return_value=forward_context):
|
||||
yield
|
||||
|
||||
|
||||
@@ -205,7 +214,8 @@ def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
|
||||
quant_config=None)
|
||||
|
||||
|
||||
def test_custom_deepseek_v2_moe(mock_distributed, base_config):
|
||||
def test_custom_deepseek_v2_moe(mock_distributed, base_config,
|
||||
mock_forward_context):
|
||||
base_config.n_shared_experts = 1
|
||||
moe = CustomDeepseekV2MoE(config=base_config,
|
||||
quant_config=None,
|
||||
|
||||
@@ -18,16 +18,18 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from vllm_ascend.ascend_forward_context import get_fused_moe_state
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.utils import adapt_patch # noqa E402
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
|
||||
|
||||
adapt_patch(True)
|
||||
|
||||
|
||||
def mock_ep_group(mocker):
|
||||
def mock_ep_and_mc2_group(mocker):
|
||||
mock_group = mocker.MagicMock()
|
||||
mock_group.rank_in_group = 0
|
||||
mock_group.rank = 0
|
||||
@@ -52,7 +54,8 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
|
||||
with patch('torch.distributed.get_rank', return_value=0), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
@@ -73,7 +76,7 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_forward_context',
|
||||
return_value=MagicMock(
|
||||
attn_metadata=MagicMock(max_num_tokens_across_dp=10),
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
|
||||
)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||
@@ -122,7 +125,14 @@ def mock_moe_env(mocker: MockerFixture):
|
||||
patch("torch_npu.npu_moe_finalize_routing", return_value=(
|
||||
torch.randn(16, 2)
|
||||
)):
|
||||
yield
|
||||
if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
|
||||
with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
|
||||
torch.randn(16, 2))), \
|
||||
patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
|
||||
torch.randn(16, 2))):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -237,11 +247,16 @@ class TestAscendFusedMoe:
|
||||
moe.moe_parallel_config.ep_size = 1
|
||||
|
||||
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
|
||||
output = moe.forward(inputs,
|
||||
router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=top_k,
|
||||
shared_experts=shared_experts)
|
||||
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
|
||||
dtype=torch.bool),
|
||||
padded_num_tokens=num_tokens)
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
output = moe.forward(inputs,
|
||||
router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=top_k,
|
||||
shared_experts=shared_experts)
|
||||
|
||||
moe.quant_method.apply.assert_called_once()
|
||||
|
||||
@@ -288,15 +303,20 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
"""
|
||||
1 test is_deepseek_v3_r1=true and use fused_expters_with_all2all
|
||||
1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
|
||||
2 test use_select_experts and fused_experts
|
||||
3 test use select_gating_topk_softmax_experts and fused_experts
|
||||
4 test use select_experts and fused_experts_with_all2all_buffer
|
||||
"""
|
||||
global_num_experts, ep_size, select_softmax = others_param
|
||||
is_prefill = False
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
|
||||
ep_size, is_prefill, is_deepseek_v3_r1))
|
||||
with patch(
|
||||
"vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
|
||||
select_softmax):
|
||||
select_softmax), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context):
|
||||
moe_method.ep_size = ep_size
|
||||
x = torch.randn(8, 2, 2)
|
||||
router_logits = torch.randn(8, 8)
|
||||
@@ -309,7 +329,7 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
top_k=2,
|
||||
renormalize=True,
|
||||
global_num_experts=global_num_experts,
|
||||
is_prefill=False)
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if ep_size == 1:
|
||||
assert result.shape == (16, 2)
|
||||
@@ -327,8 +347,13 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
4 test use_select_experts and fused_experts
|
||||
"""
|
||||
ep_size, alltoall_buffer = others_param
|
||||
is_prefill = False
|
||||
forward_context = MagicMock(
|
||||
fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
|
||||
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
|
||||
alltoall_buffer):
|
||||
alltoall_buffer), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
|
||||
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
|
||||
moe_method.ep_size = ep_size
|
||||
x = torch.randn(8, 2, 2)
|
||||
@@ -347,7 +372,7 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
renormalize=True,
|
||||
global_num_experts=128,
|
||||
expert_map=expert_map,
|
||||
is_prefill=False)
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if ep_size == 16 or ep_size == 1:
|
||||
assert result.shape == (16, 2)
|
||||
|
||||
Reference in New Issue
Block a user