Qwen3.5 MoE supports flashcomm v1 (#7644)
Cherry-picked from https://github.com/vllm-project/vllm-ascend/pull/7486.
### What this PR does / why we need it?
Multimodal models like Qwen3.5 MoE perform the embedding lookup in the model_runner, so when flashcomm v1 is enabled the first AllGather operation should be skipped.
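In effect, the change gates only the attention input projection of layer 0 in a VL model. Below is a minimal, hypothetical sketch of that condition as a standalone function; in the actual diff the same decision is computed inline from `extract_layer_index(self.layer.prefix)`, `is_vl_model()`, and the module prefix.

```python
# Hedged sketch of the gating this PR adds; not the actual vllm-ascend code.
# layer_index, vl_model, and prefix stand in for the values the real call
# site gets from extract_layer_index(...), is_vl_model(), and self.prefix.
def need_all_gather(layer_index: int, vl_model: bool, prefix: str) -> bool:
    # VL models already embedded in model_runner, so layer 0's attention
    # projection sees full-length hidden states: no first AllGather needed.
    return not (layer_index == 0 and vl_model and "attn" in prefix)

assert need_all_gather(0, True, "model.layers.0.self_attn.qkv_proj") is False
assert need_all_gather(1, True, "model.layers.1.self_attn.qkv_proj") is True
assert need_all_gather(0, False, "model.layers.0.self_attn.qkv_proj") is True
```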
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.18.0
- vLLM main: 8b6325758c
---------
Signed-off-by: Wangbingjie <wangbj1207@126.com>
Signed-off-by: wangbj127 <256472688+wangbj127@users.noreply.github.com>
```diff
@@ -97,6 +97,7 @@ class AscendGemmaRMSNorm(GemmaRMSNorm):
         import torch_npu
 
         if residual is not None:
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             if enable_custom_op():
                 x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                     x, residual, 1.0 + self.weight, None, self.variance_epsilon
```
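The new `maybe_chunk_residual` call keeps the residual in step with `x`: under flashcomm v1 the activations entering the norm are a per-rank sequence chunk, while on the first layer the residual can still be full length. A rough sketch of that shape fix-up, assuming plain slicing over a residual already padded to a multiple of the TP size (the real op is registered under `torch.ops.vllm` and handles padding itself):

```python
import torch

# Hedged sketch of what a maybe_chunk_residual-style op must guarantee;
# tp_rank and tp_size are assumed to come from the TP process group.
def maybe_chunk_residual(x: torch.Tensor, residual: torch.Tensor,
                         tp_rank: int, tp_size: int) -> torch.Tensor:
    if residual.shape[0] == x.shape[0]:
        return residual  # shapes already agree; nothing to do
    # Full-length residual: keep this rank's chunk so the fused
    # add-RMSNorm below sees matching shapes.
    chunk = residual.shape[0] // tp_size
    return residual[tp_rank * chunk:(tp_rank + 1) * chunk]
```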
```diff
@@ -57,6 +57,7 @@ from vllm.distributed import (
     tensor_model_parallel_reduce_scatter,
 )
 from vllm.distributed.parallel_state import get_tp_group
+from vllm.model_executor.models.utils import extract_layer_index
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
@@ -74,6 +75,7 @@ from vllm_ascend.utils import (
     flashcomm2_enable,
     get_flashcomm2_reorgnized_batch_ids,
     get_weight_prefetch_method,
+    is_vl_model,
     matmul_allreduce_enable,
     mlp_tp_enable,
     oproj_tp_enable,
@@ -430,8 +432,8 @@ class SequenceColumnParallelOp(CustomColumnParallelOp):
 
         # Matrix multiply.
         assert self.quant_method is not None
 
-        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+        need_all_gather = not (extract_layer_index(self.layer.prefix) == 0 and is_vl_model() and "attn" in self.prefix)
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, label=need_all_gather)
         output_parallel = self.quant_method.apply(self.layer, input_, bias)
 
         if self.gather_output:
```
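The gate keys off the layer index parsed from the module prefix. vLLM's `extract_layer_index` does that parsing; here is an illustrative stand-in (the real helper additionally enforces that the path contains exactly one integer component):

```python
def extract_layer_index_sketch(prefix: str) -> int:
    # Illustrative stand-in for vllm.model_executor.models.utils.extract_layer_index:
    # pull the single integer component out of a dotted module path.
    digits = [int(p) for p in prefix.split(".") if p.isdigit()]
    assert len(digits) == 1, f"ambiguous layer index in {prefix!r}"
    return digits[0]

assert extract_layer_index_sketch("model.layers.0.self_attn.qkv_proj") == 0
assert extract_layer_index_sketch("model.layers.17.mlp.gate_up_proj") == 17
```

Routing the decision through `label=need_all_gather` rather than a hard-coded `True` keeps a single call site for both the gathered and the skipped path.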
```diff
@@ -17,7 +17,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
 from vllm_ascend.ops.rotary_embedding import rope_forward_oot
 from vllm_ascend.ops.triton.muls_add import muls_add_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
-from vllm_ascend.utils import enable_sp_by_pass, npu_stream_switch, prefetch_stream
+from vllm_ascend.utils import enable_sp_by_pass, is_vl_model, npu_stream_switch, prefetch_stream
 
 
 def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
@@ -80,7 +80,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
         enable_sp_by_pass() and is_ep_comm
     )
 
-    if not flash_comm_v1_enabled:
+    if not flash_comm_v1_enabled or (forward_context.is_draft_model and is_vl_model()):
         return tensor_model_parallel_all_reduce(x)
 
     dp_metadata = forward_context.dp_metadata
```
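`_maybe_pad_and_reduce_impl` is the write-side counterpart of the gather: with flashcomm v1 it pads the sequence dimension to a multiple of the TP world size and reduce-scatters, so each rank keeps only its chunk of the summed activations; the new clause falls back to a plain all-reduce for the draft model of a VL pipeline. A simplified sketch of the two branches using raw `torch.distributed` collectives (the real implementation goes through vLLM's TP-group wrappers and dp_metadata-driven padding):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

# Hedged sketch of the branch structure only; flash_comm_v1 and
# draft_of_vl_model stand in for the flags the forward context provides.
def maybe_pad_and_reduce(x: torch.Tensor, flash_comm_v1: bool,
                         draft_of_vl_model: bool, tp_size: int) -> torch.Tensor:
    if not flash_comm_v1 or draft_of_vl_model:
        dist.all_reduce(x)  # fallback: classic TP all-reduce
        return x
    # flashcomm v1: pad the sequence dim, then reduce-scatter so each
    # rank keeps 1/tp_size of the summed activations.
    pad = (-x.shape[0]) % tp_size
    if pad:
        x = F.pad(x, (0, 0, 0, pad))  # pad rows of a (seq, hidden) tensor
    out = x.new_empty((x.shape[0] // tp_size, *x.shape[1:]))
    dist.reduce_scatter_tensor(out, x)
    return out
```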