Add Ascend Ops recurrent_gated_delta_rule (#6725)

### What this PR does / why we need it?
Switch the recurrent_gated_delta_rule op from the Triton implementation to the
Ascend C version for better performance.
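
For reference, a minimal sketch of the call-pattern change on the non-spec decode path, using only names that appear in the diff below. The wrapper function is illustrative only and assumes a `torch_npu` build that ships `npu_recurrent_gated_delta_rule`:

```python
import torch
import torch_npu  # Ascend extension providing npu_recurrent_gated_delta_rule
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd


def npu_recurrent_decode(q, k, v, g, beta, ssm_state, cu_seqlens, state_indices):
    """Illustrative wrapper. The old Triton path normalized q/k inside the kernel
    (use_qk_l2norm_in_kernel=True) and returned (output, final_state); the Ascend C
    path normalizes q/k up front and updates ssm_state in place via ssm_state_indices."""
    # cu_seqlens -> per-request token counts expected by the NPU kernel
    actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    q = l2norm_fwd(q)
    k = l2norm_fwd(k)
    out = torch_npu.npu_recurrent_gated_delta_rule(
        query=q.squeeze(0),
        key=k.squeeze(0),
        value=v.squeeze(0),
        g=g.squeeze(0),
        beta=beta.squeeze(0),
        state=ssm_state,  # recurrent state cache, updated in place
        scale=k.shape[-1] ** -0.5,
        actual_seq_lengths=actual_seq_lengths,
        ssm_state_indices=state_indices,
    ).unsqueeze(0)
    return out
```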
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: 9562912cea

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
Author: LeeWenquan
Date: 2026-03-09 14:14:14 +08:00
Committed by: GitHub
Parent: 23bf5d4d48
Commit: 65eae6de7b
2 changed files with 83 additions and 92 deletions


@@ -24,6 +24,8 @@ _server_cmd: &server_cmd
- "0.8"
- "--max-num-seqs"
- "64"
- "--compilation-config"
- '{"cudagraph_capture_sizes": [64]}'
_benchmarks: &benchmarks
perf:
@@ -42,7 +44,7 @@ _benchmarks: &benchmarks
request_conf: vllm_api_general_chat
dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt
max_out_len: 32768
batch_size: 32
batch_size: 64
top_k: 20
baseline: 95
threshold: 5
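
The new server flags capture an ACL graph for batch size 64. A rough offline-API equivalent, assuming the vLLM `LLM` entry point accepts `compilation_config` as a dict (as recent vLLM versions do); the model name is a placeholder, not taken from this config:

```python
from vllm import LLM

# Illustrative only: mirrors --max-num-seqs 64 and
# --compilation-config '{"cudagraph_capture_sizes": [64]}' from the benchmark config.
llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",  # placeholder model
    max_num_seqs=64,
    compilation_config={"cudagraph_capture_sizes": [64]},
)
```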


@@ -14,13 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# from collections.abc import Iterable
# mypy: ignore-errors
import torch
import torch_npu
from einops import rearrange
from torch import nn
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
from vllm.triton_utils import triton
@@ -28,12 +29,11 @@ from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
from vllm_ascend.utils import enable_sp
class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
def forward(
self,
hidden_states: torch.Tensor,
@@ -191,99 +191,88 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
mixed_qkv_non_spec = None
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec)
if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
g_spec = g
beta_spec = beta
g_non_spec = None
beta_non_spec = None
else:
g_spec = g.index_select(1, spec_token_indx)
beta_spec = beta.index_select(1, spec_token_indx)
g_non_spec = g.index_select(1, non_spec_token_indx)
beta_non_spec = beta.index_select(1, non_spec_token_indx)
g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
g_spec = g
beta_spec = beta
g_non_spec = None
beta_non_spec = None
else:
g_spec = None
beta_spec = None
g_non_spec = g
beta_non_spec = beta
g_spec = g.index_select(1, spec_token_indx)
beta_spec = beta.index_select(1, spec_token_indx)
g_non_spec = g.index_select(1, non_spec_token_indx)
beta_non_spec = beta.index_select(1, non_spec_token_indx)
else:
g_spec = None
beta_spec = None
g_non_spec = g
beta_non_spec = beta
# 2. Recurrent attention
# 2.1: Process the multi-query part
if spec_sequence_masks is not None:
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_spec,
k=key_spec,
v=value_spec,
g=g_spec,
beta=beta_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=True,
)
else:
core_attn_out_spec, last_recurrent_state = None, None
if spec_sequence_masks is not None:
cu_seqlens = spec_query_start_loc[: attn_metadata.num_spec_decodes + 1]
actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
query_spec = l2norm_fwd(query_spec)
key_spec = l2norm_fwd(key_spec)
core_attn_out_spec = torch_npu.npu_recurrent_gated_delta_rule(
query=query_spec.squeeze(0),
key=key_spec.squeeze(0),
value=value_spec.squeeze(0),
g=g_spec.squeeze(0),
beta=beta_spec.squeeze(0),
state=ssm_state,
scale=key_spec.shape[-1] ** -0.5,
actual_seq_lengths=actual_seq_lengths,
ssm_state_indices=spec_state_indices_tensor.flatten(),
num_accepted_tokens=num_accepted_tokens.to(torch.int32),
).unsqueeze(0)
else:
core_attn_out_spec, last_recurrent_state = None, None
# 2.2: Process the remaining part
if attn_metadata.num_prefills > 0:
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0
(
core_attn_out_non_spec,
last_recurrent_state,
) = chunk_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=initial_state,
output_final_state=True,
cu_seqlens=non_spec_query_start_loc,
head_first=False,
use_qk_l2norm_in_kernel=True,
)
# Init cache
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype)
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True,
)
else:
core_attn_out_non_spec, last_recurrent_state = None, None
# 2.2: Process the remaining part
if attn_metadata.num_prefills > 0:
initial_state = ssm_state[non_spec_state_indices_tensor].transpose(-1, -2).contiguous()
initial_state[~has_initial_state, ...] = 0
(
core_attn_out_non_spec,
last_recurrent_state,
) = chunk_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=initial_state,
output_final_state=True,
cu_seqlens=non_spec_query_start_loc,
head_first=False,
use_qk_l2norm_in_kernel=True,
)
ssm_state[non_spec_state_indices_tensor] = (
last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype)
)
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update(
A_log=self.A_log.contiguous(),
dt_bias=self.dt_bias.contiguous(),
q=query_non_spec.contiguous(),
k=key_non_spec.contiguous(),
v=value_non_spec.contiguous(),
a=a.contiguous(),
b=b.contiguous(),
initial_state_source=ssm_state,
initial_state_indices=non_spec_state_indices_tensor,
cu_seqlens=non_spec_query_start_loc,
use_qk_l2norm_in_kernel=True,
softplus_beta=1.0,
softplus_threshold=20.0,
)
cu_seqlens = non_spec_query_start_loc[: attn_metadata.num_decodes + 1]
actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
query_non_spec = l2norm_fwd(query_non_spec)
key_non_spec = l2norm_fwd(key_non_spec)
core_attn_out_non_spec = torch_npu.npu_recurrent_gated_delta_rule(
query=query_non_spec.squeeze(0),
key=key_non_spec.squeeze(0),
value=value_non_spec.squeeze(0),
g=g_non_spec.squeeze(0),
beta=beta_non_spec.squeeze(0),
state=ssm_state,
scale=key_non_spec.shape[-1] ** -0.5,
actual_seq_lengths=actual_seq_lengths,
ssm_state_indices=non_spec_state_indices_tensor,
).unsqueeze(0)
else:
core_attn_out_non_spec, last_recurrent_state = None, None
# 3. Merge core attention output
if spec_sequence_masks is not None and core_attn_out_non_spec is not None: