[CI][Bugfix] Quickfix for DPMetaData (#3234)

### What this PR does / why we need it?
Fix the `DPMetaData` and `Qwen3MoeSparseMoeBlock` breakage introduced by
26a7a33b88 (diff-c1550d0a38469d039370567d8981969530cbfffc7302cd1778e7c2c8a9322dea)

NOTE: vllm-ascend maintains a different sequence-parallel (SP) implementation
from upstream vLLM, so we can simply use `cu_tokens_across_sp(1)` as
`cu_tokens_across_dp_cpu`.

Closes https://github.com/vllm-project/vllm-ascend/issues/3236 and
https://github.com/vllm-project/vllm-ascend/issues/3239.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.10.2
- vLLM main:
https://github.com/vllm-project/vllm/commit/releases/v0.11.0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-09-28 21:11:22 +08:00
committed by GitHub
parent f2d8493221
commit 4ff422c730
7 changed files with 59 additions and 23 deletions

View File

@@ -8,6 +8,7 @@ from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
FusedMoEPrepareAndFinalizeWithAll2All, FusedMoEPrepareAndFinalizeWithAll2All,
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2, FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
FusedMoEPrepareAndFinalizeWithNaiveMulticast) FusedMoEPrepareAndFinalizeWithNaiveMulticast)
from vllm_ascend.utils import vllm_version_is
class TestFusedMoEPrepareAndFinalize(unittest.TestCase): class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
@@ -230,8 +231,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_get_dp_group): mock_get_dp_group):
# Mock forward context with DP metadata # Mock forward context with DP metadata
mock_context = MagicMock() mock_context = MagicMock()
mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor( if vllm_version_is("0.10.2"):
[2, 5, 7]) mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
[2, 5, 7])
else:
mock_context.dp_metadata.cu_tokens_across_sp.return_value = torch.tensor(
[2, 5, 7])
mock_get_forward_context.return_value = mock_context mock_get_forward_context.return_value = mock_context
# Setup DP group mock # Setup DP group mock

View File

@@ -28,7 +28,7 @@ from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod) AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.utils import AscendSocVersion, adapt_patch from vllm_ascend.utils import AscendSocVersion, adapt_patch, vllm_version_is
adapt_patch(True) adapt_patch(True)
@@ -93,14 +93,18 @@ def mock_dist_env(mocker: MockerFixture):
mock_moe_comm_method.finalize.side_effect = mock_finalize mock_moe_comm_method.finalize.side_effect = mock_finalize
mock_forward_context_obj = MagicMock( if vllm_version_is("0.10.2"):
moe_comm_method=mock_moe_comm_method, dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10])
moe_comm_type=MoECommType.MC2, else:
max_tokens_across_dp=10, dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]), mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
mc2_mask=torch.zeros(16, dtype=torch.bool), moe_comm_type=MoECommType.MC2,
padded_num_tokens=16, max_tokens_across_dp=10,
with_quant=False) dp_metadata=dp_metadata,
mc2_mask=torch.zeros(
16, dtype=torch.bool),
padded_num_tokens=16,
with_quant=False)
with patch('torch.distributed.get_rank', return_value=0), \ with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \ patch('torch.distributed.get_world_size', return_value=4), \

View File

@@ -26,7 +26,8 @@ from vllm_ascend.ascend_forward_context import _get_fused_moe_state
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import ( from vllm_ascend.torchair.ops.torchair_fused_moe import (
TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod) TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402 from vllm_ascend.utils import adapt_patch # noqa E402
from vllm_ascend.utils import AscendSocVersion, vllm_version_is
adapt_patch(True) adapt_patch(True)
@@ -53,6 +54,10 @@ def mock_dp_and_tp_group(mocker):
@pytest.fixture @pytest.fixture
def mock_dist_env(mocker: MockerFixture): def mock_dist_env(mocker: MockerFixture):
# init dist env patch # init dist env patch
if vllm_version_is("0.10.2"):
dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10])
else:
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
with patch('torch.distributed.get_rank', return_value=0), \ with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \ patch('torch.distributed.get_world_size', return_value=4), \
@@ -80,7 +85,7 @@ def mock_dist_env(mocker: MockerFixture):
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context', patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context',
return_value=MagicMock( return_value=MagicMock(
max_tokens_across_dp=10, max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]) dp_metadata=dp_metadata,
)), \ )), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config', patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config',
return_value=MagicMock( return_value=MagicMock(

View File

@@ -47,6 +47,7 @@ from vllm.model_executor.models.utils import (
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.utils import vllm_version_is
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -169,9 +170,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp") prefix=f"{prefix}.mlp")
else: else:
self.mlp = Qwen3MoeSparseMoeBlock(config=config, if vllm_version_is("0.10.2"):
quant_config=quant_config, self.mlp = Qwen3MoeSparseMoeBlock(
prefix=f"{prefix}.mlp") config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
prefix=f"{prefix}.mlp")
else: else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,

View File

@@ -26,6 +26,8 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.utils import vllm_version_is
class FusedMoEPrepareAndFinalize(ABC): class FusedMoEPrepareAndFinalize(ABC):
""" """
@@ -414,8 +416,12 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1: if self.moe_config.dp_size > 1:
self.cu_tokens_across_dp_cpu = get_forward_context( if vllm_version_is("0.10.2"):
).dp_metadata.cu_tokens_across_dp_cpu self.cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
else:
self.cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_sp(1)
hidden_states = self._naive_multicast(hidden_states, hidden_states = self._naive_multicast(hidden_states,
self.cu_tokens_across_dp_cpu) self.cu_tokens_across_dp_cpu)
if rm_router_logits: if rm_router_logits:

View File

@@ -56,6 +56,7 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding, from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
init_metadata_for_sp) init_metadata_for_sp)
from vllm_ascend.utils import vllm_version_is
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -311,9 +312,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp") prefix=f"{prefix}.mlp")
else: else:
self.mlp = Qwen3MoeSparseMoeBlock(config=config, if vllm_version_is("0.10.2"):
quant_config=quant_config, self.mlp = Qwen3MoeSparseMoeBlock(
prefix=f"{prefix}.mlp") config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
prefix=f"{prefix}.mlp")
else: else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,

View File

@@ -1242,8 +1242,12 @@ class TorchairAscendFusedMoE(FusedMoE):
router_logits = get_dp_group().all_gather(router_logits, 0) router_logits = get_dp_group().all_gather(router_logits, 0)
elif fused_moe_state == FusedMoEState.NaiveMulticast: elif fused_moe_state == FusedMoEState.NaiveMulticast:
cu_tokens_across_dp_cpu = get_forward_context( if vllm_version_is("0.10.2"):
).dp_metadata.cu_tokens_across_dp_cpu cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
else:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_sp(1)
hidden_states = self.naive_multicast(hidden_states, hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_dp_cpu) cu_tokens_across_dp_cpu)
if self.rm_router_logits: if self.rm_router_logits: