【OPS】qwen3-next support triton chunk_gated_delta_rule ops (#4070)

### What this PR does / why we need it?
qwen3-next suppot  triton chunk_gated_delta_rule ops

### co-owners
@OsirisDuan

- vLLM version: v0.11.2

Signed-off-by: shiyuan680 <917935075@qq.com>
This commit is contained in:
shiyuan680
2025-11-28 20:55:43 +08:00
committed by GitHub
parent 5447a039b9
commit 1c4a0468ee
13 changed files with 1625 additions and 149 deletions

View File

@@ -0,0 +1,33 @@
import torch
from tests.ut.base import PytestBase
from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule
class TestChunkGatedDeltaRule(PytestBase):
def test_triton_fusion_ops(self, mock_moe_env):
q = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu()
k = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu()
v = torch.randn(1, 17, 8, 128, dtype=torch.bfloat16).npu()
g = torch.randn(1, 17, 8, dtype=torch.float32).npu()
beta = torch.randn(1, 17, 8, dtype=torch.bfloat16).npu()
initial_state = torch.randn(3, 8, 128, 128, dtype=torch.bfloat16).npu()
q_start_loc = torch.range(0, 3, dtype=torch.int).npu()
(
core_attn_out_non_spec,
last_recurrent_state,
) = chunk_gated_delta_rule(q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=initial_state,
output_final_state=True,
cu_seqlens=q_start_loc,
head_first=False,
use_qk_l2norm_in_kernel=True)
assert core_attn_out_non_spec.shape == (1, 17, 8, 128)
assert last_recurrent_state.shape == (3, 8, 128, 128)