[CI][Bugfix] Quickfix for DPMetaData (#3234)
### What this PR does / why we need it?
Fix the `dpmetadata` and `Qwen3MoeSparseMoeBlock` breakage introduced by
26a7a33b88 (diff-c1550d0a38469d039370567d8981969530cbfffc7302cd1778e7c2c8a9322dea).
NOTE: vllm-ascend maintains its own sequence-parallel (SP) implementation,
separate from vLLM's, so we can simply use `cu_tokens_across_sp(1)` as
`cu_tokens_across_dp_cpu`.
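For reference, a minimal sketch of the version-gated access pattern this PR applies at every `cu_tokens_across_dp_cpu` read site (the wrapper function is illustrative only; `get_forward_context` and `vllm_version_is` are the real helpers used in the diff):

```python
# Illustrative wrapper; attribute/method names are taken from the diff below.
from vllm.forward_context import get_forward_context

from vllm_ascend.utils import vllm_version_is


def read_cu_tokens_across_dp_cpu():
    dp_metadata = get_forward_context().dp_metadata
    if vllm_version_is("0.10.2"):
        # vLLM v0.10.2 still exposes the cumulative token counts directly.
        return dp_metadata.cu_tokens_across_dp_cpu
    # Newer vLLM only exposes cu_tokens_across_sp(sp_size). Since
    # vllm-ascend keeps its own SP implementation, sp_size=1 reproduces
    # the old per-DP-rank cumulative tensor.
    return dp_metadata.cu_tokens_across_sp(1)
```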
Closes https://github.com/vllm-project/vllm-ascend/issues/3236 and
https://github.com/vllm-project/vllm-ascend/issues/3239.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
```diff
@@ -8,6 +8,7 @@ from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
     FusedMoEPrepareAndFinalizeWithAll2All,
     FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
     FusedMoEPrepareAndFinalizeWithNaiveMulticast)
+from vllm_ascend.utils import vllm_version_is
 
 
 class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
@@ -230,8 +231,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
                             mock_get_dp_group):
         # Mock forward context with DP metadata
         mock_context = MagicMock()
-        mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
-            [2, 5, 7])
+        if vllm_version_is("0.10.2"):
+            mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
+                [2, 5, 7])
+        else:
+            mock_context.dp_metadata.cu_tokens_across_sp.return_value = torch.tensor(
+                [2, 5, 7])
         mock_get_forward_context.return_value = mock_context
 
         # Setup DP group mock
```
```diff
@@ -28,7 +28,7 @@ from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
 from vllm_ascend.ops.moe.experts_selector import select_experts
 from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
-from vllm_ascend.utils import AscendSocVersion, adapt_patch
+from vllm_ascend.utils import AscendSocVersion, adapt_patch, vllm_version_is
 
 adapt_patch(True)
 
@@ -93,14 +93,18 @@ def mock_dist_env(mocker: MockerFixture):
 
     mock_moe_comm_method.finalize.side_effect = mock_finalize
 
-    mock_forward_context_obj = MagicMock(
-        moe_comm_method=mock_moe_comm_method,
-        moe_comm_type=MoECommType.MC2,
-        max_tokens_across_dp=10,
-        dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
-        mc2_mask=torch.zeros(16, dtype=torch.bool),
-        padded_num_tokens=16,
-        with_quant=False)
+    if vllm_version_is("0.10.2"):
+        dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10])
+    else:
+        dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
+    mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
+                                         moe_comm_type=MoECommType.MC2,
+                                         max_tokens_across_dp=10,
+                                         dp_metadata=dp_metadata,
+                                         mc2_mask=torch.zeros(
+                                             16, dtype=torch.bool),
+                                         padded_num_tokens=16,
+                                         with_quant=False)
 
     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
```
```diff
@@ -26,7 +26,8 @@ from vllm_ascend.ascend_forward_context import _get_fused_moe_state
 from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
 from vllm_ascend.torchair.ops.torchair_fused_moe import (
     TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
-from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402
+from vllm_ascend.utils import adapt_patch  # noqa E402
+from vllm_ascend.utils import AscendSocVersion, vllm_version_is
 
 adapt_patch(True)
 
@@ -53,6 +54,10 @@ def mock_dp_and_tp_group(mocker):
 @pytest.fixture
 def mock_dist_env(mocker: MockerFixture):
     # init dist env patch
+    if vllm_version_is("0.10.2"):
+        dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10])
+    else:
+        dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
 
     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
@@ -80,7 +85,7 @@ def mock_dist_env(mocker: MockerFixture):
         patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context',
               return_value=MagicMock(
                   max_tokens_across_dp=10,
-                  dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
+                  dp_metadata=dp_metadata,
               )), \
         patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config',
               return_value=MagicMock(
```
```diff
@@ -47,6 +47,7 @@ from vllm.model_executor.models.utils import (
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.utils import vllm_version_is
 
 
 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -169,9 +170,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
                                                   quant_config=quant_config,
                                                   prefix=f"{prefix}.mlp")
             else:
-                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
-                                                  quant_config=quant_config,
-                                                  prefix=f"{prefix}.mlp")
+                if vllm_version_is("0.10.2"):
+                    self.mlp = Qwen3MoeSparseMoeBlock(
+                        config=config,
+                        quant_config=quant_config,
+                        prefix=f"{prefix}.mlp")
+                else:
+                    self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
+                                                      prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
```
```diff
@@ -26,6 +26,8 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
+from vllm_ascend.utils import vllm_version_is
+
 
 class FusedMoEPrepareAndFinalize(ABC):
     """
@@ -414,8 +416,12 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
         self.enable_shared_expert_dp = enable_shared_expert_dp
 
         if self.moe_config.dp_size > 1:
-            self.cu_tokens_across_dp_cpu = get_forward_context(
-            ).dp_metadata.cu_tokens_across_dp_cpu
+            if vllm_version_is("0.10.2"):
+                self.cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_dp_cpu
+            else:
+                self.cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_sp(1)
             hidden_states = self._naive_multicast(hidden_states,
                                                   self.cu_tokens_across_dp_cpu)
             if rm_router_logits:
```
```diff
@@ -56,6 +56,7 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
                                                         init_metadata_for_sp)
+from vllm_ascend.utils import vllm_version_is
 
 
 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -311,9 +312,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
                                                   quant_config=quant_config,
                                                   prefix=f"{prefix}.mlp")
             else:
-                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
-                                                  quant_config=quant_config,
-                                                  prefix=f"{prefix}.mlp")
+                if vllm_version_is("0.10.2"):
+                    self.mlp = Qwen3MoeSparseMoeBlock(
+                        config=config,
+                        quant_config=quant_config,
+                        prefix=f"{prefix}.mlp")
+                else:
+                    self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
+                                                      prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
```
```diff
@@ -1242,8 +1242,12 @@ class TorchairAscendFusedMoE(FusedMoE):
             router_logits = get_dp_group().all_gather(router_logits, 0)
 
         elif fused_moe_state == FusedMoEState.NaiveMulticast:
-            cu_tokens_across_dp_cpu = get_forward_context(
-            ).dp_metadata.cu_tokens_across_dp_cpu
+            if vllm_version_is("0.10.2"):
+                cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_dp_cpu
+            else:
+                cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_sp(1)
             hidden_states = self.naive_multicast(hidden_states,
                                                  cu_tokens_across_dp_cpu)
             if self.rm_router_logits:
```