### What this PR does / why we need it? Adds support for the Triton chunk_gated_delta_rule ops for qwen3-next. ### co-owners @OsirisDuan - vLLM version: v0.11.2 Signed-off-by: shiyuan680 <917935075@qq.com>
34 lines
1.4 KiB
Python
34 lines
1.4 KiB
Python
import torch
|
|
|
|
from tests.ut.base import PytestBase
|
|
from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule
|
|
|
|
|
|
class TestChunkGatedDeltaRule(PytestBase):
    """Shape smoke test for the Triton ``chunk_gated_delta_rule`` op."""

    def test_triton_fusion_ops(self, mock_moe_env):
        """Call ``chunk_gated_delta_rule`` on random NPU tensors.

        Only the shapes of the two returned tensors are asserted; the
        values are not compared against a reference implementation.
        """
        # NOTE(review): ``mock_moe_env`` is never read in this body; it is
        # presumably a fixture requested for its setup side effects --
        # confirm against the fixture definition.

        # Random inputs moved to the NPU. q/k carry 4 entries on dim 2
        # while v/g/beta carry 8 (presumably head counts -- TODO confirm).
        q = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu()
        k = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu()
        v = torch.randn(1, 17, 8, 128, dtype=torch.bfloat16).npu()
        g = torch.randn(1, 17, 8, dtype=torch.float32).npu()
        beta = torch.randn(1, 17, 8, dtype=torch.bfloat16).npu()
        initial_state = torch.randn(3, 8, 128, 128,
                                    dtype=torch.bfloat16).npu()

        # ``torch.range`` is deprecated and includes its end point;
        # ``torch.arange(0, 4)`` yields the same int32 values [0, 1, 2, 3].
        # NOTE(review): these offsets stop at 3 although the inputs hold 17
        # positions on dim 1 -- confirm this is the intended cu_seqlens.
        q_start_loc = torch.arange(0, 4, dtype=torch.int).npu()

        core_attn_out_non_spec, last_recurrent_state = chunk_gated_delta_rule(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=True,
            cu_seqlens=q_start_loc,
            head_first=False,
            use_qk_l2norm_in_kernel=True,
        )

        # The output is expected to match the shape of ``v`` and the final
        # state the shape of ``initial_state``.
        assert core_attn_out_non_spec.shape == (1, 17, 8, 128)
        assert last_recurrent_state.shape == (3, 8, 128, 128)
|