Add Ascend Ops recurrent_gated_delta_rule (#6725)
### What this PR does / why we need it?
Replace the Triton implementation of the recurrent_gated_delta_rule op with
the Ascend C version (`torch_npu.npu_recurrent_gated_delta_rule`) for better performance.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main:
9562912cea
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
@@ -24,6 +24,8 @@ _server_cmd: &server_cmd
|
|||||||
- "0.8"
|
- "0.8"
|
||||||
- "--max-num-seqs"
|
- "--max-num-seqs"
|
||||||
- "64"
|
- "64"
|
||||||
|
- "--compilation-config"
|
||||||
|
- '{"cudagraph_capture_sizes": [64]}'
|
||||||
|
|
||||||
_benchmarks: &benchmarks
|
_benchmarks: &benchmarks
|
||||||
perf:
|
perf:
|
||||||
@@ -42,7 +44,7 @@ _benchmarks: &benchmarks
|
|||||||
request_conf: vllm_api_general_chat
|
request_conf: vllm_api_general_chat
|
||||||
dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt
|
dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt
|
||||||
max_out_len: 32768
|
max_out_len: 32768
|
||||||
batch_size: 32
|
batch_size: 64
|
||||||
top_k: 20
|
top_k: 20
|
||||||
baseline: 95
|
baseline: 95
|
||||||
threshold: 5
|
threshold: 5
|
||||||
|
|||||||
@@ -14,13 +14,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# from collections.abc import Iterable
|
# from collections.abc import Iterable
|
||||||
|
# mypy: ignore-errors
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch_npu
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import nn
|
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
|
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule
|
||||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
|
||||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||||
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
|
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
|
||||||
from vllm.triton_utils import triton
|
from vllm.triton_utils import triton
|
||||||
@@ -28,12 +29,11 @@ from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
|||||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||||
|
|
||||||
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
|
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
|
||||||
from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
|
|
||||||
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
||||||
from vllm_ascend.utils import enable_sp
|
from vllm_ascend.utils import enable_sp
|
||||||
|
|
||||||
|
|
||||||
class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -191,99 +191,88 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
|||||||
mixed_qkv_non_spec = None
|
mixed_qkv_non_spec = None
|
||||||
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
|
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
|
||||||
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec)
|
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec)
|
||||||
|
g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
|
||||||
if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
|
if spec_sequence_masks is not None:
|
||||||
g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
|
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
|
||||||
if spec_sequence_masks is not None:
|
g_spec = g
|
||||||
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
|
beta_spec = beta
|
||||||
g_spec = g
|
g_non_spec = None
|
||||||
beta_spec = beta
|
beta_non_spec = None
|
||||||
g_non_spec = None
|
|
||||||
beta_non_spec = None
|
|
||||||
else:
|
|
||||||
g_spec = g.index_select(1, spec_token_indx)
|
|
||||||
beta_spec = beta.index_select(1, spec_token_indx)
|
|
||||||
g_non_spec = g.index_select(1, non_spec_token_indx)
|
|
||||||
beta_non_spec = beta.index_select(1, non_spec_token_indx)
|
|
||||||
else:
|
else:
|
||||||
g_spec = None
|
g_spec = g.index_select(1, spec_token_indx)
|
||||||
beta_spec = None
|
beta_spec = beta.index_select(1, spec_token_indx)
|
||||||
g_non_spec = g
|
g_non_spec = g.index_select(1, non_spec_token_indx)
|
||||||
beta_non_spec = beta
|
beta_non_spec = beta.index_select(1, non_spec_token_indx)
|
||||||
|
else:
|
||||||
|
g_spec = None
|
||||||
|
beta_spec = None
|
||||||
|
g_non_spec = g
|
||||||
|
beta_non_spec = beta
|
||||||
|
|
||||||
# 2. Recurrent attention
|
# 2. Recurrent attention
|
||||||
|
|
||||||
# 2.1: Process the multi-query part
|
# 2.1: Process the multi-query part
|
||||||
if spec_sequence_masks is not None:
|
if spec_sequence_masks is not None:
|
||||||
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
|
cu_seqlens = spec_query_start_loc[: attn_metadata.num_spec_decodes + 1]
|
||||||
q=query_spec,
|
actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||||
k=key_spec,
|
query_spec = l2norm_fwd(query_spec)
|
||||||
v=value_spec,
|
key_spec = l2norm_fwd(key_spec)
|
||||||
g=g_spec,
|
core_attn_out_spec = torch_npu.npu_recurrent_gated_delta_rule(
|
||||||
beta=beta_spec,
|
query=query_spec.squeeze(0),
|
||||||
initial_state=ssm_state,
|
key=key_spec.squeeze(0),
|
||||||
inplace_final_state=True,
|
value=value_spec.squeeze(0),
|
||||||
cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
|
g=g_spec.squeeze(0),
|
||||||
ssm_state_indices=spec_state_indices_tensor,
|
beta=beta_spec.squeeze(0),
|
||||||
num_accepted_tokens=num_accepted_tokens,
|
state=ssm_state,
|
||||||
use_qk_l2norm_in_kernel=True,
|
scale=key_spec.shape[-1] ** -0.5,
|
||||||
)
|
actual_seq_lengths=actual_seq_lengths,
|
||||||
else:
|
ssm_state_indices=spec_state_indices_tensor.flatten(),
|
||||||
core_attn_out_spec, last_recurrent_state = None, None
|
num_accepted_tokens=num_accepted_tokens.to(torch.int32),
|
||||||
|
).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
core_attn_out_spec, last_recurrent_state = None, None
|
||||||
|
|
||||||
# 2.2: Process the remaining part
|
# 2.2: Process the remaining part
|
||||||
if attn_metadata.num_prefills > 0:
|
if attn_metadata.num_prefills > 0:
|
||||||
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
|
initial_state = ssm_state[non_spec_state_indices_tensor].transpose(-1, -2).contiguous()
|
||||||
initial_state[~has_initial_state, ...] = 0
|
|
||||||
(
|
initial_state[~has_initial_state, ...] = 0
|
||||||
core_attn_out_non_spec,
|
(
|
||||||
last_recurrent_state,
|
core_attn_out_non_spec,
|
||||||
) = chunk_gated_delta_rule(
|
last_recurrent_state,
|
||||||
q=query_non_spec,
|
) = chunk_gated_delta_rule(
|
||||||
k=key_non_spec,
|
q=query_non_spec,
|
||||||
v=value_non_spec,
|
k=key_non_spec,
|
||||||
g=g_non_spec,
|
v=value_non_spec,
|
||||||
beta=beta_non_spec,
|
g=g_non_spec,
|
||||||
initial_state=initial_state,
|
beta=beta_non_spec,
|
||||||
output_final_state=True,
|
initial_state=initial_state,
|
||||||
cu_seqlens=non_spec_query_start_loc,
|
output_final_state=True,
|
||||||
head_first=False,
|
cu_seqlens=non_spec_query_start_loc,
|
||||||
use_qk_l2norm_in_kernel=True,
|
head_first=False,
|
||||||
)
|
use_qk_l2norm_in_kernel=True,
|
||||||
# Init cache
|
)
|
||||||
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype)
|
ssm_state[non_spec_state_indices_tensor] = (
|
||||||
elif attn_metadata.num_decodes > 0:
|
last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype)
|
||||||
core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
|
)
|
||||||
q=query_non_spec,
|
|
||||||
k=key_non_spec,
|
|
||||||
v=value_non_spec,
|
|
||||||
g=g_non_spec,
|
|
||||||
beta=beta_non_spec,
|
|
||||||
initial_state=ssm_state,
|
|
||||||
inplace_final_state=True,
|
|
||||||
cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
|
|
||||||
ssm_state_indices=non_spec_state_indices_tensor,
|
|
||||||
use_qk_l2norm_in_kernel=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
core_attn_out_non_spec, last_recurrent_state = None, None
|
|
||||||
|
|
||||||
elif attn_metadata.num_decodes > 0:
|
elif attn_metadata.num_decodes > 0:
|
||||||
core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update(
|
cu_seqlens = non_spec_query_start_loc[: attn_metadata.num_decodes + 1]
|
||||||
A_log=self.A_log.contiguous(),
|
actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||||
dt_bias=self.dt_bias.contiguous(),
|
query_non_spec = l2norm_fwd(query_non_spec)
|
||||||
q=query_non_spec.contiguous(),
|
key_non_spec = l2norm_fwd(key_non_spec)
|
||||||
k=key_non_spec.contiguous(),
|
core_attn_out_non_spec = torch_npu.npu_recurrent_gated_delta_rule(
|
||||||
v=value_non_spec.contiguous(),
|
query=query_non_spec.squeeze(0),
|
||||||
a=a.contiguous(),
|
key=key_non_spec.squeeze(0),
|
||||||
b=b.contiguous(),
|
value=value_non_spec.squeeze(0),
|
||||||
initial_state_source=ssm_state,
|
g=g_non_spec.squeeze(0),
|
||||||
initial_state_indices=non_spec_state_indices_tensor,
|
beta=beta_non_spec.squeeze(0),
|
||||||
cu_seqlens=non_spec_query_start_loc,
|
state=ssm_state,
|
||||||
use_qk_l2norm_in_kernel=True,
|
scale=key_non_spec.shape[-1] ** -0.5,
|
||||||
softplus_beta=1.0,
|
actual_seq_lengths=actual_seq_lengths,
|
||||||
softplus_threshold=20.0,
|
ssm_state_indices=non_spec_state_indices_tensor,
|
||||||
)
|
).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
core_attn_out_non_spec, last_recurrent_state = None, None
|
||||||
|
|
||||||
# 3. Merge core attention output
|
# 3. Merge core attention output
|
||||||
if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
|
if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user