Files
xc-llm-ascend/tests/e2e/multicard/test_chunk_gated_delta_rule.py
shiyuan680 1c4a0468ee 【OPS】qwen3-next support triton chunk_gated_delta_rule ops (#4070)
### What this PR does / why we need it?
qwen3-next: support Triton chunk_gated_delta_rule ops

### co-owners
@OsirisDuan

- vLLM version: v0.11.2

Signed-off-by: shiyuan680 <917935075@qq.com>
2025-11-28 20:55:43 +08:00

34 lines
1.4 KiB
Python

import torch
from tests.ut.base import PytestBase
from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule
class TestChunkGatedDeltaRule(PytestBase):
    """Smoke test for the Triton ``chunk_gated_delta_rule`` op on NPU."""

    def test_triton_fusion_ops(self, mock_moe_env):
        """Run the op once on random inputs and check the output shapes.

        Only the shapes of the two returned tensors are asserted; values are
        not compared against a reference implementation.
        """
        # Presumably (batch, tokens, heads, head_dim) since head_first=False
        # -- TODO confirm against the kernel's documented layout.
        num_tokens = 17
        head_dim = 128
        qk_heads = 4
        v_heads = 8
        num_states = 3

        q = torch.randn(1, num_tokens, qk_heads, head_dim,
                        dtype=torch.bfloat16).npu()
        k = torch.randn(1, num_tokens, qk_heads, head_dim,
                        dtype=torch.bfloat16).npu()
        v = torch.randn(1, num_tokens, v_heads, head_dim,
                        dtype=torch.bfloat16).npu()
        g = torch.randn(1, num_tokens, v_heads, dtype=torch.float32).npu()
        beta = torch.randn(1, num_tokens, v_heads,
                           dtype=torch.bfloat16).npu()
        initial_state = torch.randn(num_states, v_heads, head_dim, head_dim,
                                    dtype=torch.bfloat16).npu()

        # torch.range is deprecated and end-inclusive; torch.arange with an
        # exclusive end of num_states + 1 yields the same [0, 1, 2, 3] tensor.
        # NOTE(review): these offsets span only the first 3 of the 17 tokens;
        # if cu_seqlens is meant to cover every token, the last entry should
        # probably be num_tokens -- confirm with the op owners.
        q_start_loc = torch.arange(num_states + 1, dtype=torch.int).npu()

        core_attn_out_non_spec, last_recurrent_state = chunk_gated_delta_rule(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=True,
            cu_seqlens=q_start_loc,
            head_first=False,
            use_qk_l2norm_in_kernel=True,
        )

        assert core_attn_out_non_spec.shape == (1, num_tokens, v_heads,
                                                head_dim)
        assert last_recurrent_state.shape == (num_states, v_heads, head_dim,
                                              head_dim)