Qwen3.5 MoE supports flashcomm v1 (#7644)
cherry pick from https://github.com/vllm-project/vllm-ascend/pull/7486
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
Multimodal models like Qwen3.5 MoE does embedding in model_runner, so
when flash comm is enabled, the first AllGather operation should be
skipped.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
No.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
- vLLM version: v0.18.0
- vLLM main:
8b6325758c
---------
Signed-off-by: Wangbingjie <wangbj1207@126.com>
Signed-off-by: wangbj127 <256472688+wangbj127@users.noreply.github.com>
This commit is contained in:
@@ -16,7 +16,9 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
|
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
|
||||||
#
|
#
|
||||||
|
import os
|
||||||
from tests.e2e.conftest import VllmRunner
|
from tests.e2e.conftest import VllmRunner
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
|
||||||
def test_qwen3_5_27b_distributed_mp_tp4():
|
def test_qwen3_5_27b_distributed_mp_tp4():
|
||||||
@@ -73,3 +75,31 @@ def test_qwen3_5_35b_distributed_mp_tp4_full_decode_only_mtp3():
|
|||||||
}) as vllm_model:
|
}) as vllm_model:
|
||||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
del vllm_model
|
del vllm_model
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
||||||
|
def test_qwen3_5_35b_distributed_mp_tp4_full_decode_only_mtp3_flashcomm():
|
||||||
|
example_prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
|
||||||
|
max_tokens = 20
|
||||||
|
with VllmRunner("Qwen/Qwen3.5-35B-A3B",
|
||||||
|
tensor_parallel_size=4,
|
||||||
|
enable_expert_parallel=True,
|
||||||
|
max_model_len=4096,
|
||||||
|
gpu_memory_utilization=0.90,
|
||||||
|
distributed_executor_backend="mp",
|
||||||
|
compilation_config={
|
||||||
|
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||||
|
"cudagraph_capture_sizes": [4, 8, 12, 16],
|
||||||
|
},
|
||||||
|
speculative_config={
|
||||||
|
"method": "qwen3_5_mtp",
|
||||||
|
"num_speculative_tokens": 3,
|
||||||
|
}) as vllm_model:
|
||||||
|
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
del vllm_model
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ class AscendGemmaRMSNorm(GemmaRMSNorm):
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
|
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||||
if enable_custom_op():
|
if enable_custom_op():
|
||||||
x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||||
x, residual, 1.0 + self.weight, None, self.variance_epsilon
|
x, residual, 1.0 + self.weight, None, self.variance_epsilon
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ from vllm.distributed import (
|
|||||||
tensor_model_parallel_reduce_scatter,
|
tensor_model_parallel_reduce_scatter,
|
||||||
)
|
)
|
||||||
from vllm.distributed.parallel_state import get_tp_group
|
from vllm.distributed.parallel_state import get_tp_group
|
||||||
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||||
@@ -74,6 +75,7 @@ from vllm_ascend.utils import (
|
|||||||
flashcomm2_enable,
|
flashcomm2_enable,
|
||||||
get_flashcomm2_reorgnized_batch_ids,
|
get_flashcomm2_reorgnized_batch_ids,
|
||||||
get_weight_prefetch_method,
|
get_weight_prefetch_method,
|
||||||
|
is_vl_model,
|
||||||
matmul_allreduce_enable,
|
matmul_allreduce_enable,
|
||||||
mlp_tp_enable,
|
mlp_tp_enable,
|
||||||
oproj_tp_enable,
|
oproj_tp_enable,
|
||||||
@@ -430,8 +432,8 @@ class SequenceColumnParallelOp(CustomColumnParallelOp):
|
|||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
need_all_gather = not (extract_layer_index(self.layer.prefix) == 0 and is_vl_model() and "attn" in self.prefix)
|
||||||
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
|
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, label=need_all_gather)
|
||||||
output_parallel = self.quant_method.apply(self.layer, input_, bias)
|
output_parallel = self.quant_method.apply(self.layer, input_, bias)
|
||||||
|
|
||||||
if self.gather_output:
|
if self.gather_output:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
|||||||
from vllm_ascend.ops.rotary_embedding import rope_forward_oot
|
from vllm_ascend.ops.rotary_embedding import rope_forward_oot
|
||||||
from vllm_ascend.ops.triton.muls_add import muls_add_triton
|
from vllm_ascend.ops.triton.muls_add import muls_add_triton
|
||||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||||
from vllm_ascend.utils import enable_sp_by_pass, npu_stream_switch, prefetch_stream
|
from vllm_ascend.utils import enable_sp_by_pass, is_vl_model, npu_stream_switch, prefetch_stream
|
||||||
|
|
||||||
|
|
||||||
def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -80,7 +80,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
|
|||||||
enable_sp_by_pass() and is_ep_comm
|
enable_sp_by_pass() and is_ep_comm
|
||||||
)
|
)
|
||||||
|
|
||||||
if not flash_comm_v1_enabled:
|
if not flash_comm_v1_enabled or (forward_context.is_draft_model and is_vl_model()):
|
||||||
return tensor_model_parallel_all_reduce(x)
|
return tensor_model_parallel_all_reduce(x)
|
||||||
|
|
||||||
dp_metadata = forward_context.dp_metadata
|
dp_metadata = forward_context.dp_metadata
|
||||||
|
|||||||
@@ -18,15 +18,18 @@
|
|||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
|
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
|
||||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
|
||||||
from vllm.model_executor.models.qwen3_5 import Qwen3_5GatedDeltaNet
|
from vllm.model_executor.models.qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
|
||||||
from vllm.model_executor.models.qwen3_next import Qwen3NextAttention
|
from vllm.model_executor.models.qwen3_next import Qwen3NextAttention
|
||||||
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
||||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||||
from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
|
from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
|
||||||
from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
|
from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
|
||||||
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
||||||
@@ -41,6 +44,64 @@ def to_int64_tuple(t):
|
|||||||
|
|
||||||
|
|
||||||
class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
|
class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass with three parts:
|
||||||
|
1. Input projection
|
||||||
|
2. Core attention (custom op)
|
||||||
|
3. Output projection
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Part 1: Input Projection
|
||||||
|
# ============================================================
|
||||||
|
mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
|
||||||
|
num_tokens = mixed_qkvz.size(0)
|
||||||
|
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
|
||||||
|
z_size = self.value_dim // self.tp_size
|
||||||
|
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
|
||||||
|
z = z.reshape(z.size(0), -1, self.head_v_dim)
|
||||||
|
ba, _ = self.in_proj_ba(hidden_states)
|
||||||
|
b, a = ba.chunk(2, dim=-1)
|
||||||
|
|
||||||
|
b = b.contiguous()
|
||||||
|
a = a.contiguous()
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Part 2: Core Attention (Custom Op)
|
||||||
|
# ============================================================
|
||||||
|
# Note: we should not use torch.empty here like other attention backends,
|
||||||
|
# see discussions in https://github.com/vllm-project/vllm/pull/28182
|
||||||
|
core_attn_out = torch.zeros(
|
||||||
|
(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
torch.ops.vllm.gdn_attention_core(
|
||||||
|
mixed_qkv,
|
||||||
|
b,
|
||||||
|
a,
|
||||||
|
core_attn_out,
|
||||||
|
self.prefix,
|
||||||
|
)
|
||||||
|
# ============================================================
|
||||||
|
# Part 3: Output Projection
|
||||||
|
# ============================================================
|
||||||
|
z_shape_og = z.shape
|
||||||
|
# Reshape input data into 2D tensor
|
||||||
|
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
|
||||||
|
z = z.reshape(-1, z.shape[-1])
|
||||||
|
core_attn_out = self.norm(core_attn_out, z)
|
||||||
|
core_attn_out = core_attn_out.reshape(z_shape_og)
|
||||||
|
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
|
||||||
|
o_out, _ = self.out_proj(core_attn_out)
|
||||||
|
actual_num_tokens = o_out.shape[0]
|
||||||
|
output[:actual_num_tokens] = o_out
|
||||||
|
|
||||||
def _forward_core(
|
def _forward_core(
|
||||||
self,
|
self,
|
||||||
mixed_qkv: torch.Tensor,
|
mixed_qkv: torch.Tensor,
|
||||||
@@ -320,5 +381,68 @@ class AscendQwen3NextAttention(Qwen3NextAttention):
|
|||||||
output[:], _ = self.o_proj(attn_output)
|
output[:], _ = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
|
||||||
|
class AscendQwen3_5DecoderLayer(Qwen3_5DecoderLayer):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None,
|
||||||
|
positions: torch.Tensor = None,
|
||||||
|
**kwargs: object,
|
||||||
|
):
|
||||||
|
if residual is None:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
if self.layer_idx == 0 and _EXTRA_CTX.flash_comm_v1_enabled:
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
n_out = (hidden_states.shape[0] + tp_size - 1) // tp_size
|
||||||
|
hidden_dim = hidden_states.shape[-1]
|
||||||
|
self_attention_output = torch.empty(
|
||||||
|
(n_out, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self_attention_output = torch.empty_like(hidden_states)
|
||||||
|
|
||||||
|
if self.layer_type == "linear_attention":
|
||||||
|
self.linear_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
output=self_attention_output,
|
||||||
|
)
|
||||||
|
elif self.layer_type == "full_attention":
|
||||||
|
self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
output=self_attention_output,
|
||||||
|
positions=positions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid layer_type")
|
||||||
|
hidden_states = self_attention_output
|
||||||
|
|
||||||
|
if self.layer_scale:
|
||||||
|
if len(hidden_states.shape) == 2:
|
||||||
|
hidden_states = hidden_states * (self.attn_layer_scale.to(hidden_states.dtype)[0] + 1)
|
||||||
|
else:
|
||||||
|
hidden_states = hidden_states * (self.attn_layer_scale.to(hidden_states.dtype) + 1)
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
|
if self.layer_scale:
|
||||||
|
if len(hidden_states.shape) == 2:
|
||||||
|
hidden_states = hidden_states * (self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1)
|
||||||
|
else:
|
||||||
|
assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), (
|
||||||
|
f"shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}"
|
||||||
|
)
|
||||||
|
hidden_states = hidden_states * (self.ffn_layer_scale.to(hidden_states.dtype) + 1)
|
||||||
|
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
Qwen3_5GatedDeltaNet.forward = AscendQwen3_5GatedDeltaNet.forward
|
||||||
Qwen3_5GatedDeltaNet._forward_core = AscendQwen3_5GatedDeltaNet._forward_core
|
Qwen3_5GatedDeltaNet._forward_core = AscendQwen3_5GatedDeltaNet._forward_core
|
||||||
Qwen3NextAttention.forward = AscendQwen3NextAttention.forward
|
Qwen3NextAttention.forward = AscendQwen3NextAttention.forward
|
||||||
|
Qwen3_5DecoderLayer.forward = AscendQwen3_5DecoderLayer.forward
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
]
|
]
|
||||||
|
|
||||||
self._runnable = self._run_merged_draft
|
self._runnable = self._run_merged_draft
|
||||||
|
self.is_multimodal_model = self.vllm_config.model_config.is_multimodal_model
|
||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), dtype=torch.int32, device=device)
|
self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), dtype=torch.int32, device=device)
|
||||||
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
|
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
|
||||||
@@ -414,6 +415,16 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
if is_profile:
|
if is_profile:
|
||||||
batch_size = min(batch_size, self.runner.max_num_reqs)
|
batch_size = min(batch_size, self.runner.max_num_reqs)
|
||||||
|
|
||||||
|
if self.supports_mm_inputs:
|
||||||
|
mm_embeds, is_mm_embed = (None, None)
|
||||||
|
inputs_embeds = self.model.embed_input_ids(
|
||||||
|
self.input_ids[:num_tokens], multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
|
||||||
|
)
|
||||||
|
self.inputs_embeds[:num_tokens] = inputs_embeds
|
||||||
|
inputs_embeds = self.inputs_embeds[:num_tokens]
|
||||||
|
else:
|
||||||
|
inputs_embeds = None
|
||||||
|
|
||||||
with set_ascend_forward_context(
|
with set_ascend_forward_context(
|
||||||
multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
|
multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
@@ -437,7 +448,7 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
token_indices_to_sample=self.token_indices_to_sample[: batch_size * self.extra_slots_per_request],
|
token_indices_to_sample=self.token_indices_to_sample[: batch_size * self.extra_slots_per_request],
|
||||||
# The target_position's address is same as the model_positions's
|
# The target_position's address is same as the model_positions's
|
||||||
target_positions=model_positions,
|
target_positions=model_positions,
|
||||||
inputs_embeds=None,
|
inputs_embeds=inputs_embeds,
|
||||||
multi_steps_attn_metadata=multi_steps_attn_metadata,
|
multi_steps_attn_metadata=multi_steps_attn_metadata,
|
||||||
num_tokens=num_tokens,
|
num_tokens=num_tokens,
|
||||||
)
|
)
|
||||||
@@ -1581,6 +1592,8 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if self.is_multimodal_model and _EXTRA_CTX.flash_comm_v1_enabled:
|
||||||
|
return hidden_states, positions
|
||||||
if self.method == "mtp":
|
if self.method == "mtp":
|
||||||
if _EXTRA_CTX.flash_comm_v1_enabled:
|
if _EXTRA_CTX.flash_comm_v1_enabled:
|
||||||
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
|
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
|
||||||
|
|||||||
@@ -849,7 +849,7 @@ def _is_contain_expert(config: Any):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_vl_model(vllm_config: VllmConfig):
|
def is_vl_model(vllm_config: VllmConfig = None):
|
||||||
"""Checks if the model is a VL model by config.
|
"""Checks if the model is a VL model by config.
|
||||||
|
|
||||||
Uses the same criterion as vllm itself (model_config.py): a model is
|
Uses the same criterion as vllm itself (model_config.py): a model is
|
||||||
@@ -859,6 +859,10 @@ def is_vl_model(vllm_config: VllmConfig):
|
|||||||
self (rare but possible).
|
self (rare but possible).
|
||||||
"""
|
"""
|
||||||
global _IS_VL_MODEL
|
global _IS_VL_MODEL
|
||||||
|
if vllm_config is None:
|
||||||
|
from vllm.config import get_current_vllm_config_or_none
|
||||||
|
|
||||||
|
vllm_config = get_current_vllm_config_or_none()
|
||||||
if _IS_VL_MODEL is None and vllm_config and vllm_config.model_config:
|
if _IS_VL_MODEL is None and vllm_config and vllm_config.model_config:
|
||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
# Primary: vllm's own VL detection — hf_config is the top-level
|
# Primary: vllm's own VL detection — hf_config is the top-level
|
||||||
|
|||||||
Reference in New Issue
Block a user