diff --git a/tests/e2e/nightly/single_node/models/configs/Qwen3-Next-80B-A3B-Instruct.yaml b/tests/e2e/nightly/single_node/models/configs/Qwen3-Next-80B-A3B-Instruct.yaml
index 31d4aaf4..fc126efa 100644
--- a/tests/e2e/nightly/single_node/models/configs/Qwen3-Next-80B-A3B-Instruct.yaml
+++ b/tests/e2e/nightly/single_node/models/configs/Qwen3-Next-80B-A3B-Instruct.yaml
@@ -24,6 +24,8 @@ _server_cmd: &server_cmd
   - "0.8"
   - "--max-num-seqs"
   - "64"
+  - "--compilation-config"
+  - '{"cudagraph_capture_sizes": [64]}'

 _benchmarks: &benchmarks
   perf:
@@ -42,7 +44,7 @@ _benchmarks: &benchmarks
      request_conf: vllm_api_general_chat
      dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt
      max_out_len: 32768
-     batch_size: 32
+     batch_size: 64
      top_k: 20
      baseline: 95
      threshold: 5
diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py
index 8b642587..1e36e3a0 100644
--- a/vllm_ascend/patch/worker/patch_qwen3_next.py
+++ b/vllm_ascend/patch/worker/patch_qwen3_next.py
@@ -14,13 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
 from collections.abc import Iterable
+# mypy: ignore-errors
 import torch
+import torch_npu
 from einops import rearrange
-from torch import nn
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
-from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule
+from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
 from vllm.triton_utils import triton
@@ -28,12 +29,11 @@ from vllm.v1.attention.backend import AttentionMetadata  # type: ignore
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

 from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
-from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
 from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
 from vllm_ascend.utils import enable_sp


-class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
+class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -191,99 +191,88 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
             mixed_qkv_non_spec = None

         query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
         query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec)
-
-        if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
-            g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
-            if spec_sequence_masks is not None:
-                if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
-                    g_spec = g
-                    beta_spec = beta
-                    g_non_spec = None
-                    beta_non_spec = None
-                else:
-                    g_spec = g.index_select(1, spec_token_indx)
-                    beta_spec = beta.index_select(1, spec_token_indx)
-                    g_non_spec = g.index_select(1, non_spec_token_indx)
-                    beta_non_spec = beta.index_select(1, non_spec_token_indx)
+        g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
+        if spec_sequence_masks is not None:
+            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
+                g_spec = g
+                beta_spec = beta
+                g_non_spec = None
+                beta_non_spec = None
             else:
-                g_spec = None
-                beta_spec = None
-                g_non_spec = g
-                beta_non_spec = beta
+                g_spec = g.index_select(1, spec_token_indx)
+                beta_spec = beta.index_select(1, spec_token_indx)
+                g_non_spec = g.index_select(1, non_spec_token_indx)
+                beta_non_spec = beta.index_select(1, non_spec_token_indx)
+        else:
+            g_spec = None
+            beta_spec = None
+            g_non_spec = g
+            beta_non_spec = beta

         # 2. Recurrent attention
-        # 2.1: Process the multi-query part
-        if spec_sequence_masks is not None:
-            core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
-                q=query_spec,
-                k=key_spec,
-                v=value_spec,
-                g=g_spec,
-                beta=beta_spec,
-                initial_state=ssm_state,
-                inplace_final_state=True,
-                cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
-                ssm_state_indices=spec_state_indices_tensor,
-                num_accepted_tokens=num_accepted_tokens,
-                use_qk_l2norm_in_kernel=True,
-            )
-        else:
-            core_attn_out_spec, last_recurrent_state = None, None
+        # 2.1: Process the multi-query part
+        if spec_sequence_masks is not None:
+            cu_seqlens = spec_query_start_loc[: attn_metadata.num_spec_decodes + 1]
+            actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+            query_spec = l2norm_fwd(query_spec)
+            key_spec = l2norm_fwd(key_spec)
+            core_attn_out_spec = torch_npu.npu_recurrent_gated_delta_rule(
+                query=query_spec.squeeze(0),
+                key=key_spec.squeeze(0),
+                value=value_spec.squeeze(0),
+                g=g_spec.squeeze(0),
+                beta=beta_spec.squeeze(0),
+                state=ssm_state,
+                scale=key_spec.shape[-1] ** -0.5,
+                actual_seq_lengths=actual_seq_lengths,
+                ssm_state_indices=spec_state_indices_tensor.flatten(),
+                num_accepted_tokens=num_accepted_tokens.to(torch.int32),
+            ).unsqueeze(0)
+        else:
+            core_attn_out_spec, last_recurrent_state = None, None

-        # 2.2: Process the remaining part
-        if attn_metadata.num_prefills > 0:
-            initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
-            initial_state[~has_initial_state, ...] = 0
-            (
-                core_attn_out_non_spec,
-                last_recurrent_state,
-            ) = chunk_gated_delta_rule(
-                q=query_non_spec,
-                k=key_non_spec,
-                v=value_non_spec,
-                g=g_non_spec,
-                beta=beta_non_spec,
-                initial_state=initial_state,
-                output_final_state=True,
-                cu_seqlens=non_spec_query_start_loc,
-                head_first=False,
-                use_qk_l2norm_in_kernel=True,
-            )
-            # Init cache
-            ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype)
-        elif attn_metadata.num_decodes > 0:
-            core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
-                q=query_non_spec,
-                k=key_non_spec,
-                v=value_non_spec,
-                g=g_non_spec,
-                beta=beta_non_spec,
-                initial_state=ssm_state,
-                inplace_final_state=True,
-                cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
-                ssm_state_indices=non_spec_state_indices_tensor,
-                use_qk_l2norm_in_kernel=True,
-            )
-        else:
-            core_attn_out_non_spec, last_recurrent_state = None, None
+        # 2.2: Process the remaining part
+        if attn_metadata.num_prefills > 0:
+            initial_state = ssm_state[non_spec_state_indices_tensor].transpose(-1, -2).contiguous()
+
+            initial_state[~has_initial_state, ...] = 0
+            (
+                core_attn_out_non_spec,
+                last_recurrent_state,
+            ) = chunk_gated_delta_rule(
+                q=query_non_spec,
+                k=key_non_spec,
+                v=value_non_spec,
+                g=g_non_spec,
+                beta=beta_non_spec,
+                initial_state=initial_state,
+                output_final_state=True,
+                cu_seqlens=non_spec_query_start_loc,
+                head_first=False,
+                use_qk_l2norm_in_kernel=True,
+            )
+            ssm_state[non_spec_state_indices_tensor] = (
+                last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype)
+            )
         elif attn_metadata.num_decodes > 0:
-            core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update(
-                A_log=self.A_log.contiguous(),
-                dt_bias=self.dt_bias.contiguous(),
-                q=query_non_spec.contiguous(),
-                k=key_non_spec.contiguous(),
-                v=value_non_spec.contiguous(),
-                a=a.contiguous(),
-                b=b.contiguous(),
-                initial_state_source=ssm_state,
-                initial_state_indices=non_spec_state_indices_tensor,
-                cu_seqlens=non_spec_query_start_loc,
-                use_qk_l2norm_in_kernel=True,
-                softplus_beta=1.0,
-                softplus_threshold=20.0,
-            )
+            cu_seqlens = non_spec_query_start_loc[: attn_metadata.num_decodes + 1]
+            actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+            query_non_spec = l2norm_fwd(query_non_spec)
+            key_non_spec = l2norm_fwd(key_non_spec)
+            core_attn_out_non_spec = torch_npu.npu_recurrent_gated_delta_rule(
+                query=query_non_spec.squeeze(0),
+                key=key_non_spec.squeeze(0),
+                value=value_non_spec.squeeze(0),
+                g=g_non_spec.squeeze(0),
+                beta=beta_non_spec.squeeze(0),
+                state=ssm_state,
+                scale=key_non_spec.shape[-1] ** -0.5,
+                actual_seq_lengths=actual_seq_lengths,
+                ssm_state_indices=non_spec_state_indices_tensor,
+            ).unsqueeze(0)
+        else:
+            core_attn_out_non_spec, last_recurrent_state = None, None

         # 3. Merge core attention output
         if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
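
Reviewer note: after this patch the decode and spec-decode branches share one call pattern — derive per-request token counts from the cumulative offsets, L2-normalize q/k on the host with l2norm_fwd (the removed Triton kernels did this in-kernel via use_qk_l2norm_in_kernel=True), then hand the recurrence to the fused NPU operator. Below is a minimal sketch of that pattern for the plain decode branch. The (1, T, H, D) shapes and the in-place per-slot state update are assumptions read off the squeeze(0)/unsqueeze(0) handling above, not a specification of the torch_npu operator, and the sketch only runs where torch_npu ships npu_recurrent_gated_delta_rule (Ascend NPU environments).

import torch
import torch_npu  # Ascend-only; provides the fused recurrence used in the patch
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd


def gdn_decode_step(query, key, value, g, beta, ssm_state,
                    state_indices, query_start_loc, num_decodes):
    # Assumed shapes: query/key (1, T, H, Dk); value (1, T, H, Dv); g/beta (1, T, H).
    # ssm_state holds one recurrent state per cache slot; state_indices maps each
    # decoding request to its slot.
    cu_seqlens = query_start_loc[: num_decodes + 1]
    actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # tokens per request
    # q/k are normalized outside the kernel; the replaced Triton kernels did this
    # internally (use_qk_l2norm_in_kernel=True).
    query = l2norm_fwd(query)
    key = l2norm_fwd(key)
    out = torch_npu.npu_recurrent_gated_delta_rule(
        query=query.squeeze(0),  # drop the leading batch-of-1 dim
        key=key.squeeze(0),
        value=value.squeeze(0),
        g=g.squeeze(0),
        beta=beta.squeeze(0),
        state=ssm_state,  # updated in place per slot (assumed)
        scale=key.shape[-1] ** -0.5,
        actual_seq_lengths=actual_seq_lengths,
        ssm_state_indices=state_indices,
    )
    return out.unsqueeze(0)  # restore the leading dim for downstream code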