Add Ascend Ops recurrent_gated_delta_rule (#6725)

### What this PR does / why we need it?
Replace the Triton implementation of the recurrent_gated_delta_rule op with
the Ascend C operator (`torch_npu.npu_recurrent_gated_delta_rule`) for better
performance.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: 9562912cea

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
LeeWenquan
2026-03-09 14:14:14 +08:00
committed by GitHub
parent 23bf5d4d48
commit 65eae6de7b
2 changed files with 83 additions and 92 deletions

View File

@@ -24,6 +24,8 @@ _server_cmd: &server_cmd
- "0.8" - "0.8"
- "--max-num-seqs" - "--max-num-seqs"
- "64" - "64"
- "--compilation-config"
- '{"cudagraph_capture_sizes": [64]}'
_benchmarks: &benchmarks _benchmarks: &benchmarks
perf: perf:
@@ -42,7 +44,7 @@ _benchmarks: &benchmarks
request_conf: vllm_api_general_chat request_conf: vllm_api_general_chat
dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt
max_out_len: 32768 max_out_len: 32768
batch_size: 32 batch_size: 64
top_k: 20 top_k: 20
baseline: 95 baseline: 95
threshold: 5 threshold: 5

View File

@@ -14,13 +14,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# from collections.abc import Iterable # from collections.abc import Iterable
# mypy: ignore-errors
import torch import torch
import torch_npu
from einops import rearrange from einops import rearrange
from torch import nn
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
from vllm.triton_utils import triton from vllm.triton_utils import triton
@@ -28,12 +29,11 @@ from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
from vllm_ascend.utils import enable_sp from vllm_ascend.utils import enable_sp
class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase): class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -191,8 +191,6 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
mixed_qkv_non_spec = None mixed_qkv_non_spec = None
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec) query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec)
if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias) g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None: if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
@@ -212,28 +210,31 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
beta_non_spec = beta beta_non_spec = beta
# 2. Recurrent attention # 2. Recurrent attention
# 2.1: Process the multi-query part # 2.1: Process the multi-query part
if spec_sequence_masks is not None: if spec_sequence_masks is not None:
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( cu_seqlens = spec_query_start_loc[: attn_metadata.num_spec_decodes + 1]
q=query_spec, actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
k=key_spec, query_spec = l2norm_fwd(query_spec)
v=value_spec, key_spec = l2norm_fwd(key_spec)
g=g_spec, core_attn_out_spec = torch_npu.npu_recurrent_gated_delta_rule(
beta=beta_spec, query=query_spec.squeeze(0),
initial_state=ssm_state, key=key_spec.squeeze(0),
inplace_final_state=True, value=value_spec.squeeze(0),
cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], g=g_spec.squeeze(0),
ssm_state_indices=spec_state_indices_tensor, beta=beta_spec.squeeze(0),
num_accepted_tokens=num_accepted_tokens, state=ssm_state,
use_qk_l2norm_in_kernel=True, scale=key_spec.shape[-1] ** -0.5,
) actual_seq_lengths=actual_seq_lengths,
ssm_state_indices=spec_state_indices_tensor.flatten(),
num_accepted_tokens=num_accepted_tokens.to(torch.int32),
).unsqueeze(0)
else: else:
core_attn_out_spec, last_recurrent_state = None, None core_attn_out_spec, last_recurrent_state = None, None
# 2.2: Process the remaining part # 2.2: Process the remaining part
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() initial_state = ssm_state[non_spec_state_indices_tensor].transpose(-1, -2).contiguous()
initial_state[~has_initial_state, ...] = 0 initial_state[~has_initial_state, ...] = 0
( (
core_attn_out_non_spec, core_attn_out_non_spec,
@@ -250,40 +251,28 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
head_first=False, head_first=False,
use_qk_l2norm_in_kernel=True, use_qk_l2norm_in_kernel=True,
) )
# Init cache ssm_state[non_spec_state_indices_tensor] = (
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype) last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype)
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True,
) )
else:
core_attn_out_non_spec, last_recurrent_state = None, None
elif attn_metadata.num_decodes > 0: elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update( cu_seqlens = non_spec_query_start_loc[: attn_metadata.num_decodes + 1]
A_log=self.A_log.contiguous(), actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
dt_bias=self.dt_bias.contiguous(), query_non_spec = l2norm_fwd(query_non_spec)
q=query_non_spec.contiguous(), key_non_spec = l2norm_fwd(key_non_spec)
k=key_non_spec.contiguous(), core_attn_out_non_spec = torch_npu.npu_recurrent_gated_delta_rule(
v=value_non_spec.contiguous(), query=query_non_spec.squeeze(0),
a=a.contiguous(), key=key_non_spec.squeeze(0),
b=b.contiguous(), value=value_non_spec.squeeze(0),
initial_state_source=ssm_state, g=g_non_spec.squeeze(0),
initial_state_indices=non_spec_state_indices_tensor, beta=beta_non_spec.squeeze(0),
cu_seqlens=non_spec_query_start_loc, state=ssm_state,
use_qk_l2norm_in_kernel=True, scale=key_non_spec.shape[-1] ** -0.5,
softplus_beta=1.0, actual_seq_lengths=actual_seq_lengths,
softplus_threshold=20.0, ssm_state_indices=non_spec_state_indices_tensor,
) ).unsqueeze(0)
else:
core_attn_out_non_spec, last_recurrent_state = None, None
# 3. Merge core attention output # 3. Merge core attention output
if spec_sequence_masks is not None and core_attn_out_non_spec is not None: if spec_sequence_masks is not None and core_attn_out_non_spec is not None: