diff --git a/vllm_ascend/patch/worker/patch_qwen3_5.py b/vllm_ascend/patch/worker/patch_qwen3_5.py
index 3c78d2f9..9c9a6b39 100644
--- a/vllm_ascend/patch/worker/patch_qwen3_5.py
+++ b/vllm_ascend/patch/worker/patch_qwen3_5.py
@@ -33,7 +33,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
 from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
 from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
-from vllm_ascend.utils import enable_sp, vllm_version_is
+from vllm_ascend.utils import vllm_version_is


 def to_int64_tuple(t):
@@ -141,10 +141,9 @@ class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens

-        if not enable_sp():
-            mixed_qkv = mixed_qkv[:num_actual_tokens]
-            b = b[:num_actual_tokens]
-            a = a[:num_actual_tokens]
+        mixed_qkv = mixed_qkv[:num_actual_tokens]
+        b = b[:num_actual_tokens]
+        a = a[:num_actual_tokens]

         # 1. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
@@ -315,20 +314,11 @@ class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
             )
             merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
             merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
-            if not enable_sp():
-                core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
-            else:
-                core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)[:num_actual_tokens]
+            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
         elif spec_sequence_masks is not None:
-            if not enable_sp():
-                core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
-            else:
-                core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)[:num_actual_tokens]
+            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
         else:
-            if not enable_sp():
-                core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
-            else:
-                core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)[:num_actual_tokens]
+            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)

         maybe_save_kv_layer_to_connector("", [])
diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py
index ff7e0c22..47648ab7 100644
--- a/vllm_ascend/patch/worker/patch_qwen3_next.py
+++ b/vllm_ascend/patch/worker/patch_qwen3_next.py
@@ -33,7 +33,7 @@ from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
 from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
 from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
 from vllm_ascend.patch.worker.patch_qwen3_5 import to_int64_tuple
-from vllm_ascend.utils import enable_sp, vllm_version_is
+from vllm_ascend.utils import vllm_version_is


 class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
@@ -131,10 +131,9 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens

-        if not enable_sp():
-            mixed_qkv = mixed_qkv[:num_actual_tokens]
-            b = b[:num_actual_tokens]
-            a = a[:num_actual_tokens]
+        mixed_qkv = mixed_qkv[:num_actual_tokens]
+        b = b[:num_actual_tokens]
+        a = a[:num_actual_tokens]

         # 1. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
@@ -294,20 +293,11 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
             )
             merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
             merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
-            if not enable_sp():
-                core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
-            else:
-                core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)[:num_actual_tokens]
+            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
         elif spec_sequence_masks is not None:
-            if not enable_sp():
-                core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
-            else:
-                core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)[:num_actual_tokens]
+            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
         else:
-            if not enable_sp():
-                core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
-            else:
-                core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)[:num_actual_tokens]
+            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)


 Qwen3NextGatedDeltaNet.forward = AscendQwen3Next_GatedDeltaNet.forward
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 0a9c1986..6cbd26ce 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2201,7 +2201,7 @@ class NPUModelRunner(GPUModelRunner):
         if self._has_gdn:
             attn_group = self.attn_groups[kv_cache_gid][0]
             builder = attn_group.get_metadata_builder(0)
-            if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
+            if isinstance(builder, GDNAttentionMetadataBuilder):
                 cm.query_start_loc_cpu = self.gdn_query_start_loc.cpu[: num_reqs_padded + 1]
                 cm.query_start_loc = self.gdn_query_start_loc.gpu[: num_reqs_padded + 1]
@@ -2399,6 +2399,9 @@ class NPUModelRunner(GPUModelRunner):
         num_reqs_padded = self._pad_query_start_loc_for_fia(
             num_tokens_padded, num_reqs_padded, num_reqs, cudagraph_runtime_mode, batch_desc.num_reqs
         )
+        if self._has_gdn:
+            self.gdn_query_start_loc.np[1 : num_reqs_padded + 1] = cum_num_tokens
+            self.gdn_query_start_loc.copy_to_gpu()
         pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
         attn_metadata, _ = self._build_attention_metadata(