[Bugfix] Disable the dispatch_ffn_combine kernel in MTP path (#4751)
### What this PR does / why we need it?
This PR fixes a smoke test failure. It adjusts mtp_proposer and
model_runner_v1 to route MTP decoding through the non-fused MoE
implementation while keeping the overall inference flow unchanged (a minimal
illustrative sketch of the rerouting is included below).
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: mojave2 <chenchen145@huawei.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
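For context, here is a minimal sketch of the rerouting described above. The `MoECommType` enum values are stand-ins mirroring `vllm_ascend.ascend_forward_context.MoECommType`, and `downgrade_for_mtp` is a hypothetical helper used only for illustration; the real change inlines this remap in `MtpProposer` (see the hunks below).

```python
from enum import Enum


class MoECommType(Enum):
    """Illustrative stand-in for vllm_ascend.ascend_forward_context.MoECommType."""
    ALLGATHER = "allgather"
    ALLTOALL = "alltoall"
    MC2 = "mc2"
    FUSED_ALLTOALL = "fused_alltoall"  # backed by the dispatch_ffn_combine kernel


def downgrade_for_mtp(moe_comm_type: MoECommType) -> MoECommType:
    # Hypothetical helper mirroring the remap added to MtpProposer: the MTP
    # draft pass keeps whatever the runner selected, except that the fused
    # dispatch_ffn_combine path is replaced by the non-fused ALLTOALL path.
    if moe_comm_type == MoECommType.FUSED_ALLTOALL:
        return MoECommType.ALLTOALL
    return moe_comm_type


if __name__ == "__main__":
    assert downgrade_for_mtp(MoECommType.FUSED_ALLTOALL) is MoECommType.ALLTOALL
    assert downgrade_for_mtp(MoECommType.MC2) is MoECommType.MC2  # other types pass through
```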
@@ -27,7 +27,8 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
-from vllm_ascend.ascend_forward_context import set_ascend_forward_context
+from vllm_ascend.ascend_forward_context import (MoECommType,
+                                                set_ascend_forward_context)
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
@@ -237,6 +238,9 @@ class MtpProposer(Proposer):
         ) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill)

         moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
+        # TODO: remove this after moe_comm_type selection logic is finalized
+        moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
+                         == MoECommType.FUSED_ALLTOALL else moe_comm_type)

         if skip_attn:
             attn_metadata = None
@@ -52,8 +52,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group,
-                                             get_pcp_group, get_pp_group,
-                                             get_tp_group,
+                                             get_ep_group, get_pcp_group,
+                                             get_pp_group, get_tp_group,
                                              is_global_first_rank)
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
@@ -2332,10 +2332,11 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             moe_comm_type = MoECommType.ALLGATHER

         elif soc_version in {AscendDeviceType._910_93}:
-            moe_comm_type = (MoECommType.MC2
-                             if num_tokens <= self.mc2_tokens_capacity else
-                             MoECommType.FUSED_ALLTOALL if quant_type
-                             == "w8a8_dynamic" else MoECommType.ALLTOALL)
+            # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
+            moe_comm_type = (
+                MoECommType.MC2 if num_tokens <= self.mc2_tokens_capacity else
+                MoECommType.FUSED_ALLTOALL if quant_type == "w8a8_dynamic"
+                and get_ep_group().world_size <= 16 else MoECommType.ALLTOALL)
         else:
             raise ValueError(f"Unsupported soc_version: {soc_version}")

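For reviewers, a standalone sketch of the selection behavior the last hunk encodes on 910_93 devices. `select_moe_comm_method_910_93` and its parameters are illustrative stand-ins, not the runner's real `_select_moe_comm_method` signature.

```python
from enum import Enum
from typing import Optional


class MoECommType(Enum):
    """Illustrative stand-in for the real enum."""
    ALLTOALL = "alltoall"
    MC2 = "mc2"
    FUSED_ALLTOALL = "fused_alltoall"


def select_moe_comm_method_910_93(num_tokens: int,
                                  mc2_tokens_capacity: int,
                                  quant_type: Optional[str],
                                  ep_world_size: int) -> MoECommType:
    # Sketch of the 910_93 branch after this PR (not the runner's real method):
    #   * small batches keep using MC2;
    #   * the fused dispatch_ffn_combine path (FUSED_ALLTOALL) is now also
    #     gated on the expert-parallel world size (<= 16);
    #   * everything else falls back to the non-fused ALLTOALL path.
    if num_tokens <= mc2_tokens_capacity:
        return MoECommType.MC2
    if quant_type == "w8a8_dynamic" and ep_world_size <= 16:
        return MoECommType.FUSED_ALLTOALL
    return MoECommType.ALLTOALL


if __name__ == "__main__":
    # Large batch, w8a8_dynamic quantization, EP group of 8 -> fused kernel still used
    assert select_moe_comm_method_910_93(4096, 512, "w8a8_dynamic", 8) is MoECommType.FUSED_ALLTOALL
    # Same workload with an EP group of 32 -> fused kernel disabled, plain ALLTOALL
    assert select_moe_comm_method_910_93(4096, 512, "w8a8_dynamic", 32) is MoECommType.ALLTOALL
```

Independently of these conditions, the MTP draft pass downgrades FUSED_ALLTOALL to ALLTOALL, per the MtpProposer hunk above.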