Add Ascend Ops recurrent_gated_delta_rule (#6725)
### What this PR does / why we need it?
Switch the recurrent_gated_delta_rule op from the Triton implementation to the Ascend C version
for better performance.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main: 9562912cea
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
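
For readers unfamiliar with the new op, below is a minimal sketch of the decode-path call pattern this PR introduces: queries and keys are L2-normalized with `l2norm_fwd`, per-sequence lengths are derived from `cu_seqlens`, and the Ascend C kernel `torch_npu.npu_recurrent_gated_delta_rule` takes over from the Triton `fused_recurrent_gated_delta_rule`. The keyword arguments are copied from the diff below; the helper name and the tensor layouts are assumptions made for illustration, not part of the actual module code.

```python
# Hypothetical helper mirroring the decode-path call pattern in this PR; argument
# names come from the diff, tensor shapes are assumed rather than verified.
import torch
import torch_npu  # Ascend PyTorch extension; requires an NPU environment
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd


def ascend_recurrent_gated_delta_rule(
    query: torch.Tensor,            # assumed layout [1, num_tokens, num_heads, head_dim]
    key: torch.Tensor,
    value: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    ssm_state: torch.Tensor,        # recurrent state cache, updated in place by the kernel
    cu_seqlens: torch.Tensor,       # cumulative sequence offsets, length num_seqs + 1
    state_indices: torch.Tensor,    # cache slot index per sequence
) -> torch.Tensor:
    # Per-sequence token counts derived from the cumulative offsets.
    actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    # The Ascend kernel takes pre-normalized q/k, replacing use_qk_l2norm_in_kernel=True.
    query = l2norm_fwd(query)
    key = l2norm_fwd(key)
    out = torch_npu.npu_recurrent_gated_delta_rule(
        query=query.squeeze(0),
        key=key.squeeze(0),
        value=value.squeeze(0),
        g=g.squeeze(0),
        beta=beta.squeeze(0),
        state=ssm_state,
        scale=key.shape[-1] ** -0.5,
        actual_seq_lengths=actual_seq_lengths,
        ssm_state_indices=state_indices,
    )
    return out.unsqueeze(0)
```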
@@ -24,6 +24,8 @@ _server_cmd: &server_cmd
- "0.8"
- "--max-num-seqs"
- "64"
- "--compilation-config"
- '{"cudagraph_capture_sizes": [64]}'

_benchmarks: &benchmarks
perf:
@@ -42,7 +44,7 @@ _benchmarks: &benchmarks
request_conf: vllm_api_general_chat
dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt
max_out_len: 32768
batch_size: 32
batch_size: 64
top_k: 20
baseline: 95
threshold: 5

@@ -14,13 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# from collections.abc import Iterable
# mypy: ignore-errors

import torch
import torch_npu
from einops import rearrange
from torch import nn
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
from vllm.triton_utils import triton
@@ -28,12 +29,11 @@ from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
from vllm_ascend.utils import enable_sp


class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
def forward(
self,
hidden_states: torch.Tensor,
@@ -191,99 +191,88 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
mixed_qkv_non_spec = None
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec)

if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
g_spec = g
beta_spec = beta
g_non_spec = None
beta_non_spec = None
else:
g_spec = g.index_select(1, spec_token_indx)
beta_spec = beta.index_select(1, spec_token_indx)
g_non_spec = g.index_select(1, non_spec_token_indx)
beta_non_spec = beta.index_select(1, non_spec_token_indx)
g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
g_spec = g
beta_spec = beta
g_non_spec = None
beta_non_spec = None
else:
g_spec = None
beta_spec = None
g_non_spec = g
beta_non_spec = beta
g_spec = g.index_select(1, spec_token_indx)
beta_spec = beta.index_select(1, spec_token_indx)
g_non_spec = g.index_select(1, non_spec_token_indx)
beta_non_spec = beta.index_select(1, non_spec_token_indx)
else:
g_spec = None
beta_spec = None
g_non_spec = g
beta_non_spec = beta

# 2. Recurrent attention

# 2.1: Process the multi-query part
if spec_sequence_masks is not None:
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_spec,
k=key_spec,
v=value_spec,
g=g_spec,
beta=beta_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=True,
)
else:
core_attn_out_spec, last_recurrent_state = None, None
if spec_sequence_masks is not None:
cu_seqlens = spec_query_start_loc[: attn_metadata.num_spec_decodes + 1]
actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
query_spec = l2norm_fwd(query_spec)
key_spec = l2norm_fwd(key_spec)
core_attn_out_spec = torch_npu.npu_recurrent_gated_delta_rule(
query=query_spec.squeeze(0),
key=key_spec.squeeze(0),
value=value_spec.squeeze(0),
g=g_spec.squeeze(0),
beta=beta_spec.squeeze(0),
state=ssm_state,
scale=key_spec.shape[-1] ** -0.5,
actual_seq_lengths=actual_seq_lengths,
ssm_state_indices=spec_state_indices_tensor.flatten(),
num_accepted_tokens=num_accepted_tokens.to(torch.int32),
).unsqueeze(0)
else:
core_attn_out_spec, last_recurrent_state = None, None

# 2.2: Process the remaining part
if attn_metadata.num_prefills > 0:
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0
(
core_attn_out_non_spec,
last_recurrent_state,
) = chunk_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=initial_state,
output_final_state=True,
cu_seqlens=non_spec_query_start_loc,
head_first=False,
use_qk_l2norm_in_kernel=True,
)
# Init cache
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype)
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True,
)
else:
core_attn_out_non_spec, last_recurrent_state = None, None
# 2.2: Process the remaining part
if attn_metadata.num_prefills > 0:
initial_state = ssm_state[non_spec_state_indices_tensor].transpose(-1, -2).contiguous()

initial_state[~has_initial_state, ...] = 0
(
core_attn_out_non_spec,
last_recurrent_state,
) = chunk_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=initial_state,
output_final_state=True,
cu_seqlens=non_spec_query_start_loc,
head_first=False,
use_qk_l2norm_in_kernel=True,
)
ssm_state[non_spec_state_indices_tensor] = (
last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype)
)

elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update(
A_log=self.A_log.contiguous(),
dt_bias=self.dt_bias.contiguous(),
q=query_non_spec.contiguous(),
k=key_non_spec.contiguous(),
v=value_non_spec.contiguous(),
a=a.contiguous(),
b=b.contiguous(),
initial_state_source=ssm_state,
initial_state_indices=non_spec_state_indices_tensor,
cu_seqlens=non_spec_query_start_loc,
use_qk_l2norm_in_kernel=True,
softplus_beta=1.0,
softplus_threshold=20.0,
)
cu_seqlens = non_spec_query_start_loc[: attn_metadata.num_decodes + 1]
actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
query_non_spec = l2norm_fwd(query_non_spec)
key_non_spec = l2norm_fwd(key_non_spec)
core_attn_out_non_spec = torch_npu.npu_recurrent_gated_delta_rule(
query=query_non_spec.squeeze(0),
key=key_non_spec.squeeze(0),
value=value_non_spec.squeeze(0),
g=g_non_spec.squeeze(0),
beta=beta_non_spec.squeeze(0),
state=ssm_state,
scale=key_non_spec.shape[-1] ** -0.5,
actual_seq_lengths=actual_seq_lengths,
ssm_state_indices=non_spec_state_indices_tensor,
).unsqueeze(0)
else:
core_attn_out_non_spec, last_recurrent_state = None, None

# 3. Merge core attention output
if spec_sequence_masks is not None and core_attn_out_non_spec is not None: