diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index dea1e0bf..bc287dd2 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -80,7 +80,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
         enable_sp_by_pass() and is_ep_comm
     )
-    if not flash_comm_v1_enabled or (forward_context.is_draft_model and is_vl_model()):
+    if not flash_comm_v1_enabled or (forward_context.is_draft_model and is_vl_model() and not is_ep_comm):
         return tensor_model_parallel_all_reduce(x)

     dp_metadata = forward_context.dp_metadata
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index e47f4590..7e193371 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -1595,10 +1595,8 @@ class SpecDecodeBaseProposer(EagleProposer):
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        if self.is_multimodal_model and _EXTRA_CTX.flash_comm_v1_enabled:
-            return hidden_states, positions
         if self.method == "mtp":
-            if _EXTRA_CTX.flash_comm_v1_enabled:
+            if _EXTRA_CTX.flash_comm_v1_enabled and not self.is_multimodal_model:
                 hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
             positions = positions.unsqueeze(-1)
             positions = torch.ops.vllm.maybe_pad_and_reduce(positions)