[Refactor] Adjustments to moe_comm_method selection process (#3001)
### What this PR does / why we need it?
Fix the issues raised in
https://github.com/vllm-project/vllm-ascend/pull/2791 and perform some minor
refactoring:
1. Use Enum instead of string.
2. Avoid setting a new property on forward_context in
AscendFusedMoE.forward().
3. Enable TokenDispatcherWithMoge.
4. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing:
1. Enable/Disable EP
2. Aclgraph & eager
- vLLM version: v0.10.2
- vLLM main:
9607d5eb44
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -17,56 +17,7 @@ from unittest.mock import patch
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge
|
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
|
||||||
|
|
||||||
|
|
||||||
class TestFusedExpertsMoGE(TestBase):
|
|
||||||
|
|
||||||
def test_fused_experts_moge(self):
|
|
||||||
with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
|
|
||||||
patch('torch_npu.npu_swiglu') as mock_swiglu, \
|
|
||||||
patch('vllm_ascend.utils.is_310p') as mock_is_310p:
|
|
||||||
|
|
||||||
mock_is_310p.return_value = False
|
|
||||||
|
|
||||||
mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
|
|
||||||
torch.randn(x[0].shape[0], weight[0].shape[1])
|
|
||||||
]
|
|
||||||
|
|
||||||
mock_swiglu.side_effect = lambda x: x
|
|
||||||
|
|
||||||
hidden_states = torch.randn(4, 128)
|
|
||||||
w1 = torch.randn(4, 256, 128)
|
|
||||||
w2 = torch.randn(4, 128, 128)
|
|
||||||
topk_weights = torch.rand(4, 1)
|
|
||||||
topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
|
|
||||||
top_k = 1
|
|
||||||
global_num_experts = 4
|
|
||||||
|
|
||||||
moe_parallel_config = type(
|
|
||||||
'MockConfig', (), {
|
|
||||||
'ep_size': 1,
|
|
||||||
'tp_size': 1,
|
|
||||||
'dp_size': 1,
|
|
||||||
'tp_rank': 0,
|
|
||||||
'dp_rank': 0,
|
|
||||||
'ep_rank': 0,
|
|
||||||
'use_ep': True
|
|
||||||
})()
|
|
||||||
|
|
||||||
output = fused_experts_moge(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
w1=w1,
|
|
||||||
w2=w2,
|
|
||||||
moe_parallel_config=moe_parallel_config,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
top_k=top_k,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
apply_router_weight_on_input=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(output.shape, (4, 128))
|
|
||||||
|
|
||||||
|
|
||||||
class TestLoadWeight(TestBase):
|
class TestLoadWeight(TestBase):
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from pytest_mock import MockerFixture
|
|||||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||||
AscendUnquantizedFusedMoEMethod)
|
AscendUnquantizedFusedMoEMethod)
|
||||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||||
@@ -55,6 +56,26 @@ def mock_npu_format_cast(weight_data, format):
|
|||||||
return weight_data
|
return weight_data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_vllm_config_mock(mocker: MockerFixture):
|
||||||
|
mock_hf_config = MagicMock()
|
||||||
|
mock_hf_config.model_type = "llama"
|
||||||
|
|
||||||
|
mock_model_config = MagicMock()
|
||||||
|
mock_model_config.hf_config = mock_hf_config
|
||||||
|
|
||||||
|
mock_vllm_config = MagicMock()
|
||||||
|
mock_vllm_config.model_config = mock_model_config
|
||||||
|
mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2)
|
||||||
|
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
|
||||||
|
mock_vllm_config.model_config.max_model_len = 2048
|
||||||
|
|
||||||
|
mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||||
|
return_value=mock_vllm_config)
|
||||||
|
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
|
||||||
|
return_value=mock_vllm_config)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_dist_env(mocker: MockerFixture):
|
def mock_dist_env(mocker: MockerFixture):
|
||||||
mock_moe_comm_method = MagicMock()
|
mock_moe_comm_method = MagicMock()
|
||||||
@@ -74,7 +95,7 @@ def mock_dist_env(mocker: MockerFixture):
|
|||||||
|
|
||||||
mock_forward_context_obj = MagicMock(
|
mock_forward_context_obj = MagicMock(
|
||||||
moe_comm_method=mock_moe_comm_method,
|
moe_comm_method=mock_moe_comm_method,
|
||||||
moe_comm_method_name="mc2commimpl",
|
moe_comm_type=MoECommType.MC2,
|
||||||
max_tokens_across_dp=10,
|
max_tokens_across_dp=10,
|
||||||
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
|
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
|
||||||
mc2_mask=torch.zeros(16, dtype=torch.bool),
|
mc2_mask=torch.zeros(16, dtype=torch.bool),
|
||||||
@@ -104,12 +125,6 @@ def mock_dist_env(mocker: MockerFixture):
|
|||||||
return_value=mock_forward_context_obj), \
|
return_value=mock_forward_context_obj), \
|
||||||
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
|
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
|
||||||
return_value=mock_forward_context_obj), \
|
return_value=mock_forward_context_obj), \
|
||||||
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
|
||||||
return_value=MagicMock(
|
|
||||||
parallel_config=MagicMock(tensor_parallel_size=2),
|
|
||||||
scheduler_config=MagicMock(max_num_seqs=4),
|
|
||||||
model_config=MagicMock(max_model_len=2048)
|
|
||||||
)), \
|
|
||||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
|
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
|
||||||
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
|
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
|
||||||
return_value=mock_forward_context_obj), \
|
return_value=mock_forward_context_obj), \
|
||||||
@@ -501,7 +516,7 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
mock_get_forward_context):
|
mock_get_forward_context):
|
||||||
|
|
||||||
mock_forward_context = MagicMock()
|
mock_forward_context = MagicMock()
|
||||||
mock_forward_context.moe_comm_method_name = "mc2commimpl"
|
mock_forward_context.moe_comm_type = MoECommType.MC2
|
||||||
mock_get_forward_context.return_value = mock_forward_context
|
mock_get_forward_context.return_value = mock_forward_context
|
||||||
|
|
||||||
mock_is_310p.return_value = False
|
mock_is_310p.return_value = False
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
self.moe_config.dp_group = MagicMock()
|
self.moe_config.dp_group = MagicMock()
|
||||||
self.moe_config.num_global_redundant_experts = 0
|
self.moe_config.num_global_redundant_experts = 0
|
||||||
|
|
||||||
|
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
||||||
@@ -31,7 +32,11 @@ class TestMoECommMethod(TestBase):
|
|||||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
|
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
|
||||||
def test_all_gather_comm_impl(self, mock_token_dispatcher,
|
def test_all_gather_comm_impl(self, mock_token_dispatcher,
|
||||||
mock_prepare_finalize,
|
mock_prepare_finalize,
|
||||||
mock_get_forward_context):
|
mock_get_forward_context,
|
||||||
|
mock_get_current_vllm_config):
|
||||||
|
# Mock vLLM config
|
||||||
|
mock_get_current_vllm_config.return_value = MagicMock()
|
||||||
|
|
||||||
# Mock forward context
|
# Mock forward context
|
||||||
mock_context = MagicMock()
|
mock_context = MagicMock()
|
||||||
mock_context.moe_comm_method = "all_gather"
|
mock_context.moe_comm_method = "all_gather"
|
||||||
@@ -64,13 +69,18 @@ class TestMoECommMethod(TestBase):
|
|||||||
comm_impl.finalize(h_out, reduce_results=True)
|
comm_impl.finalize(h_out, reduce_results=True)
|
||||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
|
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
|
||||||
)
|
)
|
||||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
|
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
|
||||||
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
|
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
|
||||||
mock_get_forward_context):
|
mock_get_forward_context,
|
||||||
|
mock_get_current_vllm_config):
|
||||||
|
# Mock vLLM config
|
||||||
|
mock_get_current_vllm_config.return_value = MagicMock()
|
||||||
|
|
||||||
# Mock forward context
|
# Mock forward context
|
||||||
mock_context = MagicMock()
|
mock_context = MagicMock()
|
||||||
mock_context.moe_comm_method = "mc2"
|
mock_context.moe_comm_method = "mc2"
|
||||||
@@ -104,6 +114,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
comm_impl.finalize(h_out, reduce_results=True)
|
comm_impl.finalize(h_out, reduce_results=True)
|
||||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
|
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
|
||||||
@@ -111,7 +122,11 @@ class TestMoECommMethod(TestBase):
|
|||||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
|
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
|
||||||
def test_alltoall_comm_impl(self, mock_token_dispatcher,
|
def test_alltoall_comm_impl(self, mock_token_dispatcher,
|
||||||
mock_prepare_finalize,
|
mock_prepare_finalize,
|
||||||
mock_get_forward_context):
|
mock_get_forward_context,
|
||||||
|
mock_get_current_vllm_config):
|
||||||
|
# Mock vLLM config
|
||||||
|
mock_get_current_vllm_config.return_value = MagicMock()
|
||||||
|
|
||||||
# Mock forward context
|
# Mock forward context
|
||||||
mock_context = MagicMock()
|
mock_context = MagicMock()
|
||||||
mock_context.moe_comm_method = "alltoall"
|
mock_context.moe_comm_method = "alltoall"
|
||||||
@@ -140,6 +155,7 @@ class TestMoECommMethod(TestBase):
|
|||||||
mock_pf_instance.prepare.assert_called_once_with(
|
mock_pf_instance.prepare.assert_called_once_with(
|
||||||
hidden_states, router_logits, False, False, False, None)
|
hidden_states, router_logits, False, False, False, None)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
||||||
@@ -148,7 +164,11 @@ class TestMoECommMethod(TestBase):
|
|||||||
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
|
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
|
||||||
def test_fused_experts_method(self, mock_unified_apply_mlp,
|
def test_fused_experts_method(self, mock_unified_apply_mlp,
|
||||||
mock_token_dispatcher, mock_prepare_finalize,
|
mock_token_dispatcher, mock_prepare_finalize,
|
||||||
mock_get_forward_context):
|
mock_get_forward_context,
|
||||||
|
mock_get_current_vllm_config):
|
||||||
|
# Mock vLLM config
|
||||||
|
mock_get_current_vllm_config.return_value = MagicMock()
|
||||||
|
|
||||||
# Mock forward context
|
# Mock forward context
|
||||||
mock_context = MagicMock()
|
mock_context = MagicMock()
|
||||||
mock_context.moe_comm_method = "all_gather"
|
mock_context.moe_comm_method = "all_gather"
|
||||||
|
|||||||
@@ -48,18 +48,27 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
|||||||
output_size = 56
|
output_size = 56
|
||||||
group_size = 2
|
group_size = 2
|
||||||
|
|
||||||
|
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config')
|
||||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
|
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
|
||||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
|
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
|
||||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
|
@patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
|
||||||
@patch('torch.distributed.get_rank', return_value=0)
|
@patch('torch.distributed.get_rank', return_value=0)
|
||||||
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
|
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
|
||||||
get_current_vllm_config):
|
get_current_vllm_config, mock_get_ascend_config):
|
||||||
|
# Mock ascend config
|
||||||
|
mock_ascend_config = Mock()
|
||||||
|
mock_ascend_config.dynamic_eplb = False
|
||||||
|
mock_get_ascend_config.return_value = mock_ascend_config
|
||||||
|
|
||||||
mock_vllm_config = Mock()
|
mock_vllm_config = Mock()
|
||||||
mock_vllm_config.quant_config = Mock(quant_description={
|
mock_vllm_config.quant_config = Mock(quant_description={
|
||||||
"group_size": self.group_size,
|
"group_size": self.group_size,
|
||||||
"version": "0.0.0"
|
"version": "0.0.0"
|
||||||
})
|
})
|
||||||
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
|
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
|
||||||
|
mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048,
|
||||||
|
max_model_len=2048,
|
||||||
|
enable_chunked_prefill=False)
|
||||||
get_current_vllm_config.return_value = mock_vllm_config
|
get_current_vllm_config.return_value = mock_vllm_config
|
||||||
self.quant_method = AscendW4A8DynamicFusedMoEMethod()
|
self.quant_method = AscendW4A8DynamicFusedMoEMethod()
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.utils import AscendSocVersion
|
from vllm_ascend.utils import AscendSocVersion
|
||||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||||
|
|
||||||
@@ -24,21 +25,21 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
|||||||
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
|
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
|
||||||
[
|
[
|
||||||
# Case 1: Expert parallel is disabled, should always be 'allgather'
|
# Case 1: Expert parallel is disabled, should always be 'allgather'
|
||||||
(AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"),
|
(AscendSocVersion.A2, False, 8, 100, 256, None, MoECommType.ALLGATHER),
|
||||||
(AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"),
|
(AscendSocVersion.A3, False, 16, 500, 256, None, MoECommType.ALLGATHER),
|
||||||
|
|
||||||
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
|
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
|
||||||
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"),
|
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
|
||||||
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"),
|
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
|
||||||
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition
|
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition
|
||||||
|
|
||||||
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
|
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
|
||||||
(AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"),
|
(AscendSocVersion.A2, True, 8, 100, 256, None, MoECommType.ALLGATHER),
|
||||||
(AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"),
|
(AscendSocVersion.A2, True, 16, 257, 256, None, MoECommType.ALLGATHER),
|
||||||
|
|
||||||
# Case 4: A3 SOC
|
# Case 4: A3 SOC
|
||||||
(AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"),
|
(AscendSocVersion.A3, True, 8, 100, 256, None, MoECommType.MC2),
|
||||||
(AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"),
|
(AscendSocVersion.A3, True, 8, 257, 256, None, MoECommType.ALLTOALL),
|
||||||
])
|
])
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
|
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
|
||||||
|
|||||||
@@ -22,6 +22,13 @@ class FusedMoEState(Enum):
|
|||||||
All2AllSeq = 5
|
All2AllSeq = 5
|
||||||
|
|
||||||
|
|
||||||
|
class MoECommType(Enum):
|
||||||
|
ALLGATHER = 0
|
||||||
|
MC2 = 1
|
||||||
|
ALLTOALL = 2
|
||||||
|
NAIVE_MULTICAST = 3
|
||||||
|
|
||||||
|
|
||||||
# TODO(zzzzwwjj): add soc_version to choose branch
|
# TODO(zzzzwwjj): add soc_version to choose branch
|
||||||
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
|
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||||
is_deepseek_v3_r1: bool):
|
is_deepseek_v3_r1: bool):
|
||||||
@@ -52,7 +59,7 @@ def set_ascend_forward_context(
|
|||||||
with_prefill: bool = True,
|
with_prefill: bool = True,
|
||||||
in_profile_run: bool = False,
|
in_profile_run: bool = False,
|
||||||
reserved_mc2_mask: Optional[torch.Tensor] = None,
|
reserved_mc2_mask: Optional[torch.Tensor] = None,
|
||||||
moe_comm_method: str = "",
|
moe_comm_type: Optional[MoECommType] = None,
|
||||||
num_actual_tokens: Optional[int] = None,
|
num_actual_tokens: Optional[int] = None,
|
||||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
batch_descriptor: Optional[BatchDescriptor] = None,
|
batch_descriptor: Optional[BatchDescriptor] = None,
|
||||||
@@ -72,7 +79,11 @@ def set_ascend_forward_context(
|
|||||||
batch_descriptor=batch_descriptor,
|
batch_descriptor=batch_descriptor,
|
||||||
):
|
):
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
|
|
||||||
|
from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
|
||||||
|
forward_context.moe_comm_type = moe_comm_type
|
||||||
|
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
|
||||||
|
|
||||||
forward_context.with_prefill = with_prefill
|
forward_context.with_prefill = with_prefill
|
||||||
tp_world_size = get_tensor_model_parallel_world_size()
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
ep_size = (get_ep_group().world_size if
|
ep_size = (get_ep_group().world_size if
|
||||||
|
|||||||
@@ -23,106 +23,23 @@ from vllm.config import CompilationLevel, get_current_vllm_config
|
|||||||
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe.config import \
|
|
||||||
FusedMoEParallelConfig # isort: skip
|
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
||||||
determine_default_log2phy_map)
|
determine_default_log2phy_map)
|
||||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||||
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
|
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
||||||
AlltoAllCommImpl, MC2CommImpl,
|
|
||||||
NaiveMulticastCommImpl)
|
|
||||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
|
||||||
|
|
||||||
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
||||||
|
|
||||||
|
|
||||||
def fused_experts_moge(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
moe_parallel_config: FusedMoEParallelConfig,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
top_k: int,
|
|
||||||
global_num_experts: int,
|
|
||||||
expert_map: torch.Tensor = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
|
||||||
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
|
||||||
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
|
||||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
|
||||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
|
||||||
top_k: Number of experts to select.
|
|
||||||
expert_map: Expert mapping of shape (num_experts,).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
hidden_states: Hidden states after routing.
|
|
||||||
"""
|
|
||||||
ep_size = moe_parallel_config.ep_size
|
|
||||||
local_num_experts = global_num_experts // ep_size
|
|
||||||
local_num_group = top_k // ep_size
|
|
||||||
|
|
||||||
bsz, _ = hidden_states.shape
|
|
||||||
flatten_topk_ids = topk_ids.view(-1)
|
|
||||||
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
|
||||||
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
|
||||||
sorted_hidden_states = hidden_states.index_select(
|
|
||||||
0, sorted_topk_ids // local_num_group)
|
|
||||||
|
|
||||||
experts_id = torch.arange(0,
|
|
||||||
local_num_experts,
|
|
||||||
dtype=topk_ids.dtype,
|
|
||||||
device=topk_ids.device)
|
|
||||||
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
|
||||||
torch.float32).sum(0)
|
|
||||||
topk_scales = topk_weights.view(-1).index_select(
|
|
||||||
0, sorted_topk_ids).unsqueeze(-1)
|
|
||||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
|
||||||
|
|
||||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
|
||||||
x=[sorted_hidden_states],
|
|
||||||
weight=[w1],
|
|
||||||
split_item=2,
|
|
||||||
group_list_type=0,
|
|
||||||
group_type=0,
|
|
||||||
group_list=group_list,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
if is_310p():
|
|
||||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
|
||||||
torch.float16)
|
|
||||||
else:
|
|
||||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
|
||||||
gate_up_out *= topk_scales
|
|
||||||
|
|
||||||
down_out_list = torch_npu.npu_grouped_matmul(
|
|
||||||
x=[gate_up_out],
|
|
||||||
weight=[w2],
|
|
||||||
split_item=2,
|
|
||||||
group_list_type=0,
|
|
||||||
group_type=0,
|
|
||||||
group_list=group_list,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
|
||||||
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
|
|
||||||
final_hidden_states = unsorted_hidden_states.reshape(
|
|
||||||
bsz, top_k // ep_size, -1).sum(1)
|
|
||||||
|
|
||||||
return final_hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
||||||
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
|
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
|
||||||
|
|
||||||
@@ -178,20 +95,6 @@ def forward_oot(
|
|||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
global_num_experts=global_num_experts)
|
global_num_experts=global_num_experts)
|
||||||
|
|
||||||
if topk_ids.shape[1] < top_k or is_310p():
|
|
||||||
assert global_num_experts is not None
|
|
||||||
return fused_experts_moge(
|
|
||||||
hidden_states=x,
|
|
||||||
w1=layer.w13_weight,
|
|
||||||
w2=layer.w2_weight,
|
|
||||||
moe_parallel_config=self.moe.moe_parallel_config,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
top_k=top_k,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=expert_map,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
|
||||||
|
|
||||||
moe_comm_method = get_forward_context().moe_comm_method
|
moe_comm_method = get_forward_context().moe_comm_method
|
||||||
return moe_comm_method.fused_experts(hidden_states=x,
|
return moe_comm_method.fused_experts(hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
@@ -277,13 +180,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
if self.dynamic_eplb:
|
if self.dynamic_eplb:
|
||||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
||||||
|
|
||||||
for method in {
|
setup_moe_comm_method(self.moe_config)
|
||||||
AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
|
|
||||||
NaiveMulticastCommImpl
|
|
||||||
}:
|
|
||||||
setattr(
|
|
||||||
self, method.__name__.lower(),
|
|
||||||
method(moe_config=self.moe_config)) # type: ignore[abstract]
|
|
||||||
|
|
||||||
def update_expert_map(self, new_expert_map):
|
def update_expert_map(self, new_expert_map):
|
||||||
self.expert_map = new_expert_map
|
self.expert_map = new_expert_map
|
||||||
@@ -307,8 +204,8 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
outputs since each rank only has partial outputs.
|
outputs since each rank only has partial outputs.
|
||||||
"""
|
"""
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
moe_comm_type = forward_context.moe_comm_type
|
||||||
if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
|
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
else:
|
else:
|
||||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
@@ -318,10 +215,6 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
|
||||||
|
|
||||||
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
|
|
||||||
|
|
||||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||||
hidden_states=hidden_states, router_logits=router_logits)
|
hidden_states=hidden_states, router_logits=router_logits)
|
||||||
|
|
||||||
@@ -449,8 +342,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
|||||||
|
|
||||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
moe_comm_type = forward_context.moe_comm_type
|
||||||
if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
|
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
|
||||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||||
|
|
||||||
_, fused_out = AscendFusedMoE.forward(
|
_, fused_out = AscendFusedMoE.forward(
|
||||||
|
|||||||
@@ -41,9 +41,7 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
|||||||
determine_default_log2phy_map)
|
determine_default_log2phy_map)
|
||||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||||
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
|
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
||||||
AlltoAllCommImpl, MC2CommImpl,
|
|
||||||
NaiveMulticastCommImpl)
|
|
||||||
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
|
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
|
||||||
get_all_reduce_merge_state,
|
get_all_reduce_merge_state,
|
||||||
@@ -339,13 +337,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
self.moe_config.mc2_group = get_mc2_group()
|
self.moe_config.mc2_group = get_mc2_group()
|
||||||
self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num
|
self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num
|
||||||
|
|
||||||
for method in {
|
setup_moe_comm_method(self.moe_config)
|
||||||
AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
|
|
||||||
NaiveMulticastCommImpl
|
|
||||||
}:
|
|
||||||
setattr(
|
|
||||||
self, method.__name__.lower(),
|
|
||||||
method(moe_config=self.moe_config)) # type: ignore[abstract]
|
|
||||||
|
|
||||||
def update_expert_map(self, new_expert_map):
|
def update_expert_map(self, new_expert_map):
|
||||||
self.expert_map = new_expert_map
|
self.expert_map = new_expert_map
|
||||||
@@ -360,22 +352,6 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
if self.moe_load is not None:
|
if self.moe_load is not None:
|
||||||
self.moe_load.zero_()
|
self.moe_load.zero_()
|
||||||
|
|
||||||
def naive_multicast(self, x: torch.Tensor,
|
|
||||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
|
||||||
assert (len(x.shape) == 2)
|
|
||||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
|
||||||
device=x.device,
|
|
||||||
dtype=x.dtype)
|
|
||||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
|
||||||
self.dp_rank - 1]
|
|
||||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
|
||||||
buffer[start:end, :].copy_(x)
|
|
||||||
for idx in range(self.dp_size):
|
|
||||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
|
||||||
end = cu_tokens_across_dp_cpu[idx]
|
|
||||||
get_dp_group().broadcast(buffer[start:end, :], idx)
|
|
||||||
return buffer
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
@@ -412,9 +388,6 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
mc2_mask = chunk_mc2_mask[tp_rank]
|
mc2_mask = chunk_mc2_mask[tp_rank]
|
||||||
replace_allreduce = True
|
replace_allreduce = True
|
||||||
|
|
||||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
|
||||||
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
|
|
||||||
|
|
||||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
|
|||||||
@@ -13,14 +13,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
|
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
|
||||||
FusedMoEPrepareAndFinalizeWithAll2All,
|
FusedMoEPrepareAndFinalizeWithAll2All,
|
||||||
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
|
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
|
||||||
@@ -28,13 +31,31 @@ from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
|
|||||||
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
|
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
|
||||||
from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
|
from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
|
||||||
TokenDispatcherWithAllGather,
|
TokenDispatcherWithAllGather,
|
||||||
TokenDispatcherWithMC2)
|
TokenDispatcherWithMC2,
|
||||||
|
TokenDispatcherWithMoge)
|
||||||
|
|
||||||
|
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_moe_comm_method(
|
||||||
|
moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
|
||||||
|
return _MoECommMethods.get(moe_comm_type)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_moe_comm_method(moe_config):
|
||||||
|
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
|
||||||
|
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
|
||||||
|
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
|
||||||
|
_MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
|
||||||
|
moe_config)
|
||||||
|
|
||||||
|
|
||||||
class MoECommMethod(ABC):
|
class MoECommMethod(ABC):
|
||||||
"""Base class for MoE communication methods."""
|
"""Base class for MoE communication methods."""
|
||||||
|
|
||||||
def __init__(self, moe_config: FusedMoEConfig):
|
def __init__(self, moe_config: FusedMoEConfig):
|
||||||
|
self.model_type = get_current_vllm_config(
|
||||||
|
).model_config.hf_config.model_type
|
||||||
self.moe_config = moe_config
|
self.moe_config = moe_config
|
||||||
self.mc2_mask = None
|
self.mc2_mask = None
|
||||||
|
|
||||||
@@ -113,8 +134,8 @@ class MoECommMethod(ABC):
|
|||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
with_quant=use_int8_w8a8 or use_int4_w4a8)
|
with_quant=use_int8_w8a8 or use_int4_w4a8)
|
||||||
|
|
||||||
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = \
|
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \
|
||||||
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"]
|
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales")
|
||||||
|
|
||||||
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
|
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
@@ -126,6 +147,7 @@ class MoECommMethod(ABC):
|
|||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
w1_scale_bias=w1_scale_bias,
|
w1_scale_bias=w1_scale_bias,
|
||||||
w2_scale_bias=w2_scale_bias,
|
w2_scale_bias=w2_scale_bias,
|
||||||
|
topk_scales=topk_scales,
|
||||||
with_quant=use_int8_w8a8
|
with_quant=use_int8_w8a8
|
||||||
or use_int4_w4a8,
|
or use_int4_w4a8,
|
||||||
fusion=use_int8_w8a8,
|
fusion=use_int8_w8a8,
|
||||||
@@ -170,94 +192,21 @@ class AllGatherCommImpl(MoECommMethod):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_token_dispatcher(self):
|
def _get_token_dispatcher(self):
|
||||||
return TokenDispatcherWithAllGather(
|
if self.model_type == "PanguProMoE":
|
||||||
top_k=self.moe_config.experts_per_token,
|
return TokenDispatcherWithMoge(
|
||||||
num_experts=self.moe_config.num_experts,
|
top_k=self.moe_config.experts_per_token,
|
||||||
num_local_experts=self.moe_config.num_local_experts)
|
num_experts=self.moe_config.num_experts,
|
||||||
|
num_local_experts=self.moe_config.num_local_experts)
|
||||||
|
else:
|
||||||
|
return TokenDispatcherWithAllGather(
|
||||||
|
top_k=self.moe_config.experts_per_token,
|
||||||
|
num_experts=self.moe_config.num_experts,
|
||||||
|
num_local_experts=self.moe_config.num_local_experts)
|
||||||
|
|
||||||
def _get_fused_moe_prepare_finalize(self):
|
def _get_fused_moe_prepare_finalize(self):
|
||||||
return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
|
return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
|
||||||
|
|
||||||
|
|
||||||
class NativeAllGatherCommImpl(AllGatherCommImpl):
|
|
||||||
"""This implementation should be compatible with all scenarios.
|
|
||||||
|
|
||||||
Note that this implementation purely consists of native PyTorch ops
|
|
||||||
and does not use any NPU-specific ops. So the performance may not be optimal.
|
|
||||||
But it is a good fallback for scenarios where NPU-specific ops are not available.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def permute(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
expert_map: torch.Tensor,
|
|
||||||
num_experts: int,
|
|
||||||
apply_a8_quantization: bool,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
|
||||||
num_tokens = hidden_states.shape[0]
|
|
||||||
|
|
||||||
# Generate token indices and flatten
|
|
||||||
token_indices = torch.arange(num_tokens,
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=torch.int64)
|
|
||||||
token_indices = (token_indices.unsqueeze(1).expand(
|
|
||||||
-1, self.moe_config.experts_per_token).reshape(-1))
|
|
||||||
|
|
||||||
# Flatten token-to-expert mappings and map to local experts
|
|
||||||
weights_flat = topk_weights.view(-1)
|
|
||||||
experts_flat = topk_ids.view(-1)
|
|
||||||
local_experts_flat = (expert_map[experts_flat]
|
|
||||||
if expert_map is not None else experts_flat)
|
|
||||||
|
|
||||||
# Filter valid token-expert pairs
|
|
||||||
mask = local_experts_flat != -1
|
|
||||||
# FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
|
||||||
# So we need to filter out invalid tokens by zeroing their weights.
|
|
||||||
# This is a workaround and should be removed after the issue is fixed
|
|
||||||
filtered_weights = torch.where(mask, weights_flat,
|
|
||||||
torch.zeros_like(weights_flat)).to(
|
|
||||||
topk_weights.dtype)
|
|
||||||
filtered_experts = torch.where(
|
|
||||||
mask,
|
|
||||||
local_experts_flat,
|
|
||||||
torch.full_like(local_experts_flat, num_experts),
|
|
||||||
).to(topk_ids.dtype)
|
|
||||||
|
|
||||||
# Sort by local expert IDs
|
|
||||||
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
|
|
||||||
self.sorted_token_indices = token_indices[sort_indices]
|
|
||||||
self.sorted_weights = filtered_weights[sort_indices]
|
|
||||||
|
|
||||||
# Compute token counts with minlength of num_experts
|
|
||||||
# This is equivalent to but faster than:
|
|
||||||
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
|
||||||
token_counts = torch.zeros(num_experts + 1,
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=torch.int64)
|
|
||||||
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
|
||||||
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
|
||||||
expert_tokens = token_counts[:num_experts]
|
|
||||||
|
|
||||||
# Rearrange hidden_states
|
|
||||||
permuted_hidden_states = hidden_states[self.sorted_token_indices]
|
|
||||||
|
|
||||||
group_list_type = 1 # `count` mode
|
|
||||||
|
|
||||||
return permuted_hidden_states, expert_tokens, None, group_list_type
|
|
||||||
|
|
||||||
def unpermute(self, mlp_output: torch.Tensor,
|
|
||||||
hidden_states: torch.Tensor) -> None:
|
|
||||||
mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)
|
|
||||||
|
|
||||||
final_hidden_states = torch.zeros_like(hidden_states)
|
|
||||||
final_hidden_states.index_add_(0, self.sorted_token_indices,
|
|
||||||
mlp_output)
|
|
||||||
|
|
||||||
hidden_states[:] = final_hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class MC2CommImpl(MoECommMethod):
|
class MC2CommImpl(MoECommMethod):
|
||||||
"""This implementation is for the scenarios listed below:
|
"""This implementation is for the scenarios listed below:
|
||||||
1. `enable_expert_parallel=True`.
|
1. `enable_expert_parallel=True`.
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import torch_npu
|
|||||||
from torch.nn.functional import pad
|
from torch.nn.functional import pad
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.utils import dispose_tensor, is_310p
|
from vllm_ascend.utils import dispose_tensor, is_310p
|
||||||
|
|
||||||
|
|
||||||
@@ -76,7 +77,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
bias1, bias2 = None, None
|
bias1, bias2 = None, None
|
||||||
_output_dtype = w2_scale.dtype
|
_output_dtype = w2_scale.dtype
|
||||||
|
|
||||||
is_mc2 = get_forward_context().moe_comm_method_name == "mc2commimpl"
|
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
||||||
if w1_scale_bias is None and is_mc2:
|
if w1_scale_bias is None and is_mc2:
|
||||||
if w1_scale.dtype != torch.float32:
|
if w1_scale.dtype != torch.float32:
|
||||||
w1_scale = w1_scale.to(torch.float32)
|
w1_scale = w1_scale.to(torch.float32)
|
||||||
|
|||||||
@@ -377,14 +377,13 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
|||||||
|
|
||||||
|
|
||||||
# mypy: disable-error-code="override"
|
# mypy: disable-error-code="override"
|
||||||
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
class TokenDispatcherWithMoge(MoETokenDispatcher):
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.apply_router_weight_on_input = False
|
self.apply_router_weight_on_input = False
|
||||||
self.local_ep = 1
|
self.local_num_experts = self.num_experts // self.ep_size
|
||||||
self.local_num_experts = self.num_experts // self.local_ep
|
self.local_num_group = self.top_k // self.ep_size
|
||||||
self.local_num_group = self.top_k // self.local_ep
|
|
||||||
self.bsz = None
|
self.bsz = None
|
||||||
|
|
||||||
def token_dispatch(self,
|
def token_dispatch(self,
|
||||||
@@ -401,17 +400,6 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
|||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
with_quant: bool = False):
|
with_quant: bool = False):
|
||||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
|
||||||
if self.apply_router_weight_on_input:
|
|
||||||
assert (topk_weights.dim() == 2
|
|
||||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
|
||||||
_, topk = topk_weights.shape
|
|
||||||
assert (
|
|
||||||
topk == 1
|
|
||||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
|
||||||
hidden_states = hidden_states * \
|
|
||||||
topk_weights.to(hidden_states.dtype)
|
|
||||||
|
|
||||||
self.bsz, _ = hidden_states.shape
|
self.bsz, _ = hidden_states.shape
|
||||||
flatten_topk_ids = topk_ids.view(-1)
|
flatten_topk_ids = topk_ids.view(-1)
|
||||||
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||||
@@ -445,7 +433,7 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
|||||||
unsorted_hidden_states = hidden_states.index_select(
|
unsorted_hidden_states = hidden_states.index_select(
|
||||||
0, unsorted_topk_ids)
|
0, unsorted_topk_ids)
|
||||||
final_hidden_states = unsorted_hidden_states.reshape(
|
final_hidden_states = unsorted_hidden_states.reshape(
|
||||||
self.bsz, self.top_k // self.local_ep, -1).sum(1)
|
self.bsz, self.top_k // self.ep_size, -1).sum(1)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -117,11 +117,11 @@ class EagleProposer(Proposer):
|
|||||||
skip_attn: bool = False,
|
skip_attn: bool = False,
|
||||||
num_reqs: int = 0,
|
num_reqs: int = 0,
|
||||||
num_tokens_across_dp: Optional[torch.Tensor] = None):
|
num_tokens_across_dp: Optional[torch.Tensor] = None):
|
||||||
moe_comm_method = self.runner._select_moe_comm_method(
|
moe_comm_type = self.runner._select_moe_comm_method(
|
||||||
num_tokens, with_prefill)
|
num_tokens, with_prefill)
|
||||||
with set_ascend_forward_context(None,
|
with set_ascend_forward_context(None,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
moe_comm_method=moe_comm_method,
|
moe_comm_type=moe_comm_type,
|
||||||
num_tokens=num_tokens):
|
num_tokens=num_tokens):
|
||||||
self.model(
|
self.model(
|
||||||
input_ids=self.input_ids[:num_tokens],
|
input_ids=self.input_ids[:num_tokens],
|
||||||
@@ -454,7 +454,7 @@ class EagleProposer(Proposer):
|
|||||||
with_prefill = attn_metadata.attn_state not in [
|
with_prefill = attn_metadata.attn_state not in [
|
||||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||||
]
|
]
|
||||||
moe_comm_method = self.runner._select_moe_comm_method(
|
moe_comm_type = self.runner._select_moe_comm_method(
|
||||||
num_input_tokens, with_prefill)
|
num_input_tokens, with_prefill)
|
||||||
|
|
||||||
# copy inputs to buffer for cudagraph
|
# copy inputs to buffer for cudagraph
|
||||||
@@ -463,7 +463,7 @@ class EagleProposer(Proposer):
|
|||||||
attn_metadata.block_tables = block_table.to(device)
|
attn_metadata.block_tables = block_table.to(device)
|
||||||
with set_ascend_forward_context(attn_metadata,
|
with set_ascend_forward_context(attn_metadata,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
moe_comm_method=moe_comm_method,
|
moe_comm_type=moe_comm_type,
|
||||||
num_tokens=num_input_tokens):
|
num_tokens=num_input_tokens):
|
||||||
last_hidden_states, hidden_states = self.model(
|
last_hidden_states, hidden_states = self.model(
|
||||||
input_ids=self.input_ids[:num_input_tokens],
|
input_ids=self.input_ids[:num_input_tokens],
|
||||||
@@ -495,7 +495,7 @@ class EagleProposer(Proposer):
|
|||||||
else:
|
else:
|
||||||
input_batch_size = batch_size
|
input_batch_size = batch_size
|
||||||
|
|
||||||
moe_comm_method = self.runner._select_moe_comm_method(
|
moe_comm_type = self.runner._select_moe_comm_method(
|
||||||
input_batch_size, False)
|
input_batch_size, False)
|
||||||
|
|
||||||
attn_metadata.num_actual_tokens = batch_size
|
attn_metadata.num_actual_tokens = batch_size
|
||||||
@@ -568,7 +568,7 @@ class EagleProposer(Proposer):
|
|||||||
# Run the model.
|
# Run the model.
|
||||||
with set_ascend_forward_context(attn_metadata,
|
with set_ascend_forward_context(attn_metadata,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
moe_comm_method=moe_comm_method,
|
moe_comm_type=moe_comm_type,
|
||||||
num_tokens=input_batch_size):
|
num_tokens=input_batch_size):
|
||||||
|
|
||||||
last_hidden_states, hidden_states = self.model(
|
last_hidden_states, hidden_states = self.model(
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ class MtpProposer(Proposer):
|
|||||||
_) = self.runner._sync_metadata_across_dp(num_tokens,
|
_) = self.runner._sync_metadata_across_dp(num_tokens,
|
||||||
with_prefill, False)
|
with_prefill, False)
|
||||||
|
|
||||||
moe_comm_method = self.runner._select_moe_comm_method(
|
moe_comm_type = self.runner._select_moe_comm_method(
|
||||||
num_tokens, with_prefill)
|
num_tokens, with_prefill)
|
||||||
|
|
||||||
is_running_torchair = self.torchair_graph_enabled and \
|
is_running_torchair = self.torchair_graph_enabled and \
|
||||||
@@ -146,7 +146,7 @@ class MtpProposer(Proposer):
|
|||||||
with_prefill=with_prefill,
|
with_prefill=with_prefill,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||||
moe_comm_method=moe_comm_method,
|
moe_comm_type=moe_comm_type,
|
||||||
in_profile_run=self.runner.in_profile_run,
|
in_profile_run=self.runner.in_profile_run,
|
||||||
num_actual_tokens=0):
|
num_actual_tokens=0):
|
||||||
if is_running_torchair:
|
if is_running_torchair:
|
||||||
@@ -425,7 +425,7 @@ class MtpProposer(Proposer):
|
|||||||
num_tokens_across_dp = self.runner.num_tokens_across_dp
|
num_tokens_across_dp = self.runner.num_tokens_across_dp
|
||||||
with_prefill = self.runner.with_prefill
|
with_prefill = self.runner.with_prefill
|
||||||
|
|
||||||
moe_comm_method = self.runner._select_moe_comm_method(
|
moe_comm_type = self.runner._select_moe_comm_method(
|
||||||
num_input_tokens, with_prefill)
|
num_input_tokens, with_prefill)
|
||||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||||
uniform_decode=False)
|
uniform_decode=False)
|
||||||
@@ -440,7 +440,7 @@ class MtpProposer(Proposer):
|
|||||||
with_prefill=with_prefill,
|
with_prefill=with_prefill,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||||
moe_comm_method=moe_comm_method,
|
moe_comm_type=moe_comm_type,
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||||
in_profile_run=self.runner.in_profile_run,
|
in_profile_run=self.runner.in_profile_run,
|
||||||
num_actual_tokens=num_tokens):
|
num_actual_tokens=num_tokens):
|
||||||
|
|||||||
@@ -94,7 +94,8 @@ from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
|
|||||||
scatter_mm_placeholders)
|
scatter_mm_placeholders)
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
from vllm_ascend.ascend_forward_context import (MoECommType,
|
||||||
|
set_ascend_forward_context)
|
||||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
@@ -1860,7 +1861,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _select_moe_comm_method(self, num_tokens: int,
|
def _select_moe_comm_method(self, num_tokens: int,
|
||||||
with_prefill: bool) -> str:
|
with_prefill: bool) -> MoECommType:
|
||||||
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
|
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
|
||||||
are designed for expert parallelism.
|
are designed for expert parallelism.
|
||||||
2. If expert parallel is enabled, we need to consider the soc version and the
|
2. If expert parallel is enabled, we need to consider the soc version and the
|
||||||
@@ -1881,36 +1882,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
ValueError: If the soc version is unsupported.
|
ValueError: If the soc version is unsupported.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The selected MoE communication method, either "allgather", "mc2", or "alltoall".
|
MoECommType: The selected MoE communication method.
|
||||||
"""
|
"""
|
||||||
soc_version = get_ascend_soc_version()
|
soc_version = get_ascend_soc_version()
|
||||||
quant_type = getattr(self.vllm_config.model_config.hf_config,
|
quant_type = getattr(self.vllm_config.model_config.hf_config,
|
||||||
'moe_quantize', None)
|
'moe_quantize', None)
|
||||||
|
model_type = self.vllm_config.model_config.hf_config.model_type
|
||||||
|
|
||||||
if not self.parallel_config.enable_expert_parallel:
|
if not self.parallel_config.enable_expert_parallel:
|
||||||
moe_comm_method = "allgather"
|
moe_comm_type = MoECommType.ALLGATHER
|
||||||
elif soc_version in {AscendSocVersion.A2}:
|
elif soc_version in {AscendSocVersion.A2}:
|
||||||
if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size_across_dp >= 16:
|
if (num_tokens <= self.mc2_tokens_capacity
|
||||||
moe_comm_method = "mc2"
|
and self.parallel_config.world_size_across_dp >= 16):
|
||||||
|
moe_comm_type = MoECommType.MC2
|
||||||
else:
|
else:
|
||||||
|
# Currently, w4a8_dynamic does not support allgatherep
|
||||||
if quant_type == "w4a8_dynamic":
|
if quant_type == "w4a8_dynamic":
|
||||||
moe_comm_method = "alltoall"
|
moe_comm_type = MoECommType.ALLTOALL
|
||||||
else:
|
else:
|
||||||
moe_comm_method = "allgather"
|
moe_comm_type = MoECommType.ALLGATHER
|
||||||
|
|
||||||
elif soc_version in {AscendSocVersion.A3}:
|
elif soc_version in {AscendSocVersion.A3}:
|
||||||
moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall"
|
moe_comm_type = (MoECommType.MC2
|
||||||
|
if num_tokens <= self.mc2_tokens_capacity else
|
||||||
|
MoECommType.ALLTOALL)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
||||||
|
|
||||||
if moe_comm_method == "allgather" and with_prefill:
|
if moe_comm_type == MoECommType.ALLGATHER and with_prefill:
|
||||||
moe_comm_method = "naivemulticast"
|
moe_comm_type = MoECommType.NAIVE_MULTICAST
|
||||||
|
|
||||||
|
# PanguProMoE only supports allgather
|
||||||
|
if model_type == "PanguProMoE":
|
||||||
|
moe_comm_type = MoECommType.ALLGATHER
|
||||||
|
|
||||||
if is_global_first_rank():
|
if is_global_first_rank():
|
||||||
logger.debug(f"num_tokens: {num_tokens}, "
|
logger.debug(f"num_tokens: {num_tokens}, "
|
||||||
f"moe_comm_method: {moe_comm_method}")
|
f"moe_comm_type: {moe_comm_type}")
|
||||||
|
return moe_comm_type
|
||||||
return moe_comm_method
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
@@ -1942,8 +1951,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.dynamic_eplb:
|
if self.dynamic_eplb:
|
||||||
self.eplb_updator.take_update_info_from_eplb_process()
|
self.eplb_updator.take_update_info_from_eplb_process()
|
||||||
|
|
||||||
moe_comm_method = self._select_moe_comm_method(num_input_tokens,
|
moe_comm_type = self._select_moe_comm_method(num_input_tokens,
|
||||||
self.with_prefill)
|
self.with_prefill)
|
||||||
|
|
||||||
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
|
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
|
||||||
scheduler_output.total_num_scheduled_tokens
|
scheduler_output.total_num_scheduled_tokens
|
||||||
@@ -1962,7 +1971,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
with_prefill=self.with_prefill,
|
with_prefill=self.with_prefill,
|
||||||
reserved_mc2_mask=self.reserved_mc2_mask,
|
reserved_mc2_mask=self.reserved_mc2_mask,
|
||||||
moe_comm_method=moe_comm_method,
|
moe_comm_type=moe_comm_type,
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||||
batch_descriptor=batch_descriptor,
|
batch_descriptor=batch_descriptor,
|
||||||
num_actual_tokens=scheduler_output.
|
num_actual_tokens=scheduler_output.
|
||||||
@@ -2351,8 +2360,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||||
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
|
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
|
||||||
|
|
||||||
moe_comm_method = self._select_moe_comm_method(num_tokens,
|
moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)
|
||||||
with_prefill)
|
|
||||||
|
|
||||||
# If cudagraph_mode.decode_mode() == FULL and
|
# If cudagraph_mode.decode_mode() == FULL and
|
||||||
# cudagraph_mode.seperate_routine(). This means that we are using
|
# cudagraph_mode.seperate_routine(). This means that we are using
|
||||||
@@ -2472,7 +2480,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
with_prefill=with_prefill,
|
with_prefill=with_prefill,
|
||||||
in_profile_run=self.in_profile_run,
|
in_profile_run=self.in_profile_run,
|
||||||
reserved_mc2_mask=self.reserved_mc2_mask,
|
reserved_mc2_mask=self.reserved_mc2_mask,
|
||||||
moe_comm_method=moe_comm_method,
|
moe_comm_type=moe_comm_type,
|
||||||
num_actual_tokens=0,
|
num_actual_tokens=0,
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||||
batch_descriptor=batch_descriptor,
|
batch_descriptor=batch_descriptor,
|
||||||
|
|||||||
Reference in New Issue
Block a user