diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 7fae512b..4920b1f7 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -332,3 +332,14 @@
 #    https://github.com/vllm-project/vllm/pull/36225
 #    Future Plan:
 #    Remove this patch when vLLM merges the PR.
+#
+# ** 17. File: worker/patch_qwen3_5.py**
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.model_executor.models.qwen3_5.Qwen3_5GatedDeltaNet._forward_core`
+#    Why:
+#    The class Qwen3_5GatedDeltaNet reuses the `_forward_core` method of Qwen3NextGatedDeltaNet,
+#    but the AscendC ops used by Qwen3NextGatedDeltaNet do not support `ssm_state` in float32 format.
+#    How:
+#    Patch `Qwen3_5GatedDeltaNet._forward_core` to use Triton ops such as `fused_recurrent_gated_delta_rule`.
+#    Future Plan:
+#    Remove this patch when all ops in `_forward_core` support both Qwen3_5 and Qwen3Next.
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index f5a50aaf..5f2d9a2a 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -17,9 +17,14 @@
 
 from vllm.triton_utils import HAS_TRITON
 
+from vllm_ascend.utils import vllm_version_is
+
 if HAS_TRITON:
     import vllm_ascend.patch.worker.patch_triton
 
+if not vllm_version_is("v0.16.0"):
+    import vllm_ascend.patch.worker.patch_qwen3_5  # noqa
+
 # isort: off
 import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 import vllm_ascend.patch.worker.patch_unquantized_gemm  # noqa
diff --git a/vllm_ascend/patch/worker/patch_qwen3_5.py b/vllm_ascend/patch/worker/patch_qwen3_5.py
new file mode 100644
index 00000000..536e4695
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_qwen3_5.py
@@ -0,0 +1,257 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# from collections.abc import Iterable
+# mypy: ignore-errors
+
+
+import torch
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.fla.ops import (
+    chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
+from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
+    causal_conv1d_update)
+from vllm.model_executor.models.qwen3_5 import Qwen3_5GatedDeltaNet
+from vllm.v1.attention.backend import AttentionMetadata  # type: ignore
+from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
+from vllm.v1.attention.backends.utils import PAD_SLOT_ID
+
+from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
+from vllm_ascend.ops.triton.fla.sigmoid_gating import (
+    fused_sigmoid_gating_delta_rule_update)
+from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
+from vllm_ascend.utils import enable_sp
+
+
+class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
+
+    def _forward_core(
+        self,
+        mixed_qkv: torch.Tensor,
+        b: torch.Tensor,
+        a: torch.Tensor,
+        core_attn_out: torch.Tensor,
+    ):
+        # Core attention computation (called by custom op).
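+        #
+        # The computation proceeds in three steps, matching the numbered
+        # section markers below: (1) causal-conv1d sequence transformation,
+        # (2) gated delta-rule recurrent attention (chunked kernel for
+        # prefill, fused recurrent kernels for decode / speculative decode),
+        # and (3) merging the spec / non-spec outputs into `core_attn_out`.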
+
+        # NOTE: The processing logic of Qwen3_5GatedDeltaNet is the same as
+        # that of Qwen3NextGatedDeltaNet. However, because the op
+        # `torch_npu.npu_recurrent_gated_delta_rule` currently does not
+        # support `ssm_state` inputs in float32 format, we temporarily retain
+        # the current _forward_core implementation. Once the op supports
+        # float32 `ssm_state`, this patch should be removed.
+
+        forward_context = get_forward_context()
+        attn_metadata: AttentionMetadata = forward_context.attn_metadata
+
+        if attn_metadata is None:
+            # V1 profile run
+            return
+
+        assert isinstance(attn_metadata, dict)
+        attn_metadata = attn_metadata[self.prefix]
+        assert isinstance(attn_metadata, GDNAttentionMetadata)
+        has_initial_state = attn_metadata.has_initial_state
+        spec_query_start_loc = attn_metadata.spec_query_start_loc
+        non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
+        spec_sequence_masks = attn_metadata.spec_sequence_masks
+        spec_token_indx = attn_metadata.spec_token_indx
+        non_spec_token_indx = attn_metadata.non_spec_token_indx
+        spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
+        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
+        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+        conv_state = self_kv_cache[0].transpose(-1, -2)
+        ssm_state = self_kv_cache[1]
+        num_actual_tokens = attn_metadata.num_actual_tokens
+        num_accepted_tokens = attn_metadata.num_accepted_tokens
+
+        if not enable_sp():
+            mixed_qkv = mixed_qkv[:num_actual_tokens]
+            b = b[:num_actual_tokens]
+            a = a[:num_actual_tokens]
+
+        # 1. Convolution sequence transformation
+        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
+                                               self.conv1d.weight.size(2))
+        if spec_sequence_masks is not None:
+            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
+                mixed_qkv_spec = mixed_qkv
+                mixed_qkv_non_spec = None
+            else:
+                mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
+                mixed_qkv_non_spec = mixed_qkv.index_select(
+                    0, non_spec_token_indx)
+        else:
+            mixed_qkv_spec = None
+            mixed_qkv_non_spec = mixed_qkv
+
+        # 1.1: Process the multi-query part
+        if spec_sequence_masks is not None:
+            mixed_qkv_spec = causal_conv1d_update(
+                mixed_qkv_spec,
+                conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=spec_state_indices_tensor[:, 0]
+                [:attn_metadata.num_spec_decodes],
+                num_accepted_tokens=num_accepted_tokens,
+                query_start_loc=spec_query_start_loc,
+                max_query_len=spec_state_indices_tensor.size(-1),
+                validate_data=False,
+            )
+
+        # 1.2: Process the remaining part
+        if attn_metadata.num_prefills > 0:
+            if mixed_qkv_non_spec is not None:
+                conv_weights_T = conv_weights.transpose(0, 1)
+                mixed_qkv_non_spec = torch.ops._C_ascend.causal_conv1d_fn(
+                    mixed_qkv_non_spec,
+                    conv_weights_T,
+                    self.conv1d.bias,
+                    activation=self.activation,
+                    conv_state=self_kv_cache[0],
+                    has_initial_state=has_initial_state,
+                    non_spec_state_indices_tensor=non_spec_state_indices_tensor,
+                    non_spec_query_start_loc=non_spec_query_start_loc,
+                    pad_slot_id=PAD_SLOT_ID,
+                )
+        elif attn_metadata.num_decodes > 0:
+            mixed_qkv_non_spec = causal_conv1d_update(
+                mixed_qkv_non_spec,
+                conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=non_spec_state_indices_tensor[
+                    :attn_metadata.num_actual_tokens],
+                validate_data=True,
+            )
+        else:
+            mixed_qkv_non_spec = None
+
+        query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(
+            mixed_qkv_spec)
+        query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
+            mixed_qkv_non_spec)
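+
+        # Gating: `fused_gdn_gating_patch` computes the decay gate `g` and the
+        # mixing coefficient `beta` for all tokens at once (expected to match
+        # the Qwen3-Next GDN gating, i.e. g = -exp(A_log) * softplus(a +
+        # dt_bias), beta = sigmoid(b)); both are then split into spec /
+        # non-spec parts with the same token indices used for `mixed_qkv`.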
+
+        if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
+            g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
+            if spec_sequence_masks is not None:
+                if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
+                    g_spec = g
+                    beta_spec = beta
+                    g_non_spec = None
+                    beta_non_spec = None
+                else:
+                    g_spec = g.index_select(1, spec_token_indx)
+                    beta_spec = beta.index_select(1, spec_token_indx)
+                    g_non_spec = g.index_select(1, non_spec_token_indx)
+                    beta_non_spec = beta.index_select(1, non_spec_token_indx)
+            else:
+                g_spec = None
+                beta_spec = None
+                g_non_spec = g
+                beta_non_spec = beta
+
+            # 2. Recurrent attention
+
+            # 2.1: Process the multi-query part
+            if spec_sequence_masks is not None:
+                core_attn_out_spec, last_recurrent_state = \
+                    fused_recurrent_gated_delta_rule(
+                        q=query_spec,
+                        k=key_spec,
+                        v=value_spec,
+                        g=g_spec,
+                        beta=beta_spec,
+                        initial_state=ssm_state,
+                        inplace_final_state=True,
+                        cu_seqlens=spec_query_start_loc[
+                            :attn_metadata.num_spec_decodes + 1],
+                        ssm_state_indices=spec_state_indices_tensor,
+                        num_accepted_tokens=num_accepted_tokens,
+                        use_qk_l2norm_in_kernel=True,
+                    )
+            else:
+                core_attn_out_spec, last_recurrent_state = None, None
+
+            # 2.2: Process the remaining part
+            if attn_metadata.num_prefills > 0:
+                initial_state = ssm_state[
+                    non_spec_state_indices_tensor].contiguous()
+                initial_state[~has_initial_state, ...] = 0
+                (
+                    core_attn_out_non_spec,
+                    last_recurrent_state,
+                ) = chunk_gated_delta_rule(
+                    q=query_non_spec,
+                    k=key_non_spec,
+                    v=value_non_spec,
+                    g=g_non_spec,
+                    beta=beta_non_spec,
+                    initial_state=initial_state,
+                    output_final_state=True,
+                    cu_seqlens=non_spec_query_start_loc,
+                    head_first=False,
+                    use_qk_l2norm_in_kernel=True,
+                )
+                # Init cache
+                ssm_state[non_spec_state_indices_tensor] = \
+                    last_recurrent_state.to(ssm_state.dtype)
+            elif attn_metadata.num_decodes > 0:
+                core_attn_out_non_spec, last_recurrent_state = \
+                    fused_recurrent_gated_delta_rule(
+                        q=query_non_spec,
+                        k=key_non_spec,
+                        v=value_non_spec,
+                        g=g_non_spec,
+                        beta=beta_non_spec,
+                        initial_state=ssm_state,
+                        inplace_final_state=True,
+                        cu_seqlens=non_spec_query_start_loc[
+                            :attn_metadata.num_decodes + 1],
+                        ssm_state_indices=non_spec_state_indices_tensor,
+                        use_qk_l2norm_in_kernel=True,
+                    )
+            else:
+                core_attn_out_non_spec, last_recurrent_state = None, None
+
+        elif attn_metadata.num_decodes > 0:
+            core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update(
+                A_log=self.A_log.contiguous(),
+                dt_bias=self.dt_bias.contiguous(),
+                q=query_non_spec.contiguous(),
+                k=key_non_spec.contiguous(),
+                v=value_non_spec.contiguous(),
+                a=a.contiguous(),
+                b=b.contiguous(),
+                initial_state_source=ssm_state,
+                initial_state_indices=non_spec_state_indices_tensor,
+                cu_seqlens=non_spec_query_start_loc,
+                use_qk_l2norm_in_kernel=True,
+                softplus_beta=1.0,
+                softplus_threshold=20.0,
+            )
+
+        # 3. Merge core attention output
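+        # `core_attn_out_spec` / `core_attn_out_non_spec` are scattered back
+        # to their original token positions via `index_copy_`, then written
+        # (up to `num_actual_tokens`) into the caller-provided `core_attn_out`.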
+        if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
+            merged_out = torch.empty(
+                (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
+                dtype=core_attn_out_non_spec.dtype,
+                device=core_attn_out_non_spec.device,
+            )
+            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
+            merged_out.index_copy_(1, non_spec_token_indx,
+                                   core_attn_out_non_spec)
+            if not enable_sp():
+                core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
+            else:
+                core_attn_out[:num_actual_tokens] = merged_out.squeeze(
+                    0)[:num_actual_tokens]
+        elif spec_sequence_masks is not None:
+            if not enable_sp():
+                core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(
+                    0)
+            else:
+                core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(
+                    0)[:num_actual_tokens]
+        else:
+            if not enable_sp():
+                core_attn_out[:num_actual_tokens] = \
+                    core_attn_out_non_spec.squeeze(0)
+            else:
+                core_attn_out[:num_actual_tokens] = \
+                    core_attn_out_non_spec.squeeze(0)[:num_actual_tokens]
+        maybe_save_kv_layer_to_connector("", [])
+
+
+Qwen3_5GatedDeltaNet._forward_core = AscendQwen3_5GatedDeltaNet._forward_core
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 9a498a28..b821524a 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -189,6 +189,8 @@ class AscendEagleProposer(EagleProposer):
                 "Qwen2_5_VLForConditionalGeneration",
                 "Qwen3VLForConditionalGeneration",
                 "Qwen3VLMoeForConditionalGeneration",
+                "Qwen3_5ForConditionalGeneration",
+                "Qwen3_5MoeForConditionalGeneration",
         ]:
             self.model.config.image_token_index = model.config.image_token_id
         elif self.get_model_name(model) == "PixtralForConditionalGeneration":