[qwen3 next ]add ascend c casual_conv1d_fn (#6661)
### What this PR does / why we need it?
add ascend c casual_conv1d_fn
- vLLM version: v0.15.0
- vLLM main:
13397841ab
---------
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -22,11 +22,12 @@ from einops import rearrange
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule
|
||||
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
|
||||
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
|
||||
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
|
||||
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
||||
@@ -163,20 +164,18 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
|
||||
# 1.2: Process the remaining part
|
||||
if attn_metadata.num_prefills > 0:
|
||||
if mixed_qkv_non_spec is not None:
|
||||
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "state_indices_tensor"
|
||||
mixed_qkv_non_spec = causal_conv1d_fn(
|
||||
mixed_qkv_non_spec_T,
|
||||
conv_weights,
|
||||
conv_weights_T = conv_weights.transpose(0, 1)
|
||||
mixed_qkv_non_spec = torch.ops._C_ascend.causal_conv1d_fn(
|
||||
mixed_qkv_non_spec,
|
||||
conv_weights_T,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
conv_state=self_kv_cache[0],
|
||||
has_initial_state=has_initial_state,
|
||||
cache_indices=non_spec_state_indices_tensor,
|
||||
query_start_loc=non_spec_query_start_loc,
|
||||
metadata=attn_metadata,
|
||||
).transpose(0, 1)
|
||||
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
|
||||
non_spec_query_start_loc=non_spec_query_start_loc,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
elif attn_metadata.num_decodes > 0:
|
||||
mixed_qkv_non_spec = causal_conv1d_update(
|
||||
mixed_qkv_non_spec,
|
||||
|
||||
Reference in New Issue
Block a user