[UT] Add Triton ops UT: test_fused_qkvzba_split_reshape_cat (#5474)
### What this PR does / why we need it?
Add a unit test for the Triton op `fused_qkvzba_split_reshape_cat`, validating its fused split/reshape/concat output against the reference path in `Qwen3NextGatedDeltaNet.fix_query_key_value_ordering`.
### Does this PR introduce _any_ user-facing change?
No, this is a test-only change.
### How was this patch tested?
pytest -sv tests/ut/ops/test_fused_qkvzba_split_reshape_cat.py
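To list the generated parametrized cases without running them, pytest's standard collection flag can be used:
pytest --collect-only -q tests/ut/ops/test_fused_qkvzba_split_reshape_cat.py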
- vLLM version: v0.13.0
- vLLM main: 5326c89803
---------
Signed-off-by: ZT-AIA <1028681969@qq.com>
@@ -0,0 +1,100 @@
import pytest
import torch
from einops import rearrange
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet

from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \
    fused_qkvzba_split_reshape_cat


def validate_cmp(y_cal, y_ref, dtype, device='npu'):
    """Compare a computed tensor against a reference with dtype-aware tolerances."""
    y_cal = y_cal.to(device)
    y_ref = y_ref.to(device)
    if dtype in (torch.float16, torch.bfloat16):
        # Half-precision results get a looser tolerance.
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=5e-03,
                                   atol=5e-03,
                                   equal_nan=True)
    elif dtype == torch.float32:
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=1e-03,
                                   atol=1e-03,
                                   equal_nan=True)
    elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64,
                   torch.uint32, torch.bool):
        # Integer and boolean results must match exactly.
        assert torch.equal(y_cal, y_ref)
    else:
        raise ValueError('Invalid "dtype" parameter: {}'.format(dtype))
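

# A minimal standalone sketch of validate_cmp usage (device="cpu" here is an
# illustrative override; the test below keeps the default "npu"):
#
#     x = torch.randn(4, 8, dtype=torch.float16)
#     validate_cmp(x, x.clone(), torch.float16, device="cpu")  # identical tensors pass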
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_len", [1, 16, 64, 128, 256, 1024, 2048, 3567])
|
||||
@pytest.mark.parametrize("num_heads_qk", [2, 4, 8, 16])
|
||||
@pytest.mark.parametrize("num_heads_v", [2, 4, 8])
|
||||
@pytest.mark.parametrize("head_qk_dim", [64, 128, 256])
|
||||
@pytest.mark.parametrize("head_v_dim", [64, 128])
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
def test_fused_qkvzba_split_reshape_cat(
|
||||
seq_len,
|
||||
num_heads_qk,
|
||||
num_heads_v,
|
||||
head_qk_dim,
|
||||
head_v_dim,
|
||||
dtype,
|
||||
):
|
||||
if num_heads_v % num_heads_qk != 0:
|
||||
pytest.skip("num_heads_v must be divisible by num_heads_qk")
|
||||
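
    # Of the parametrized (num_heads_qk, num_heads_v) pairs, only those where
    # num_heads_v is a multiple of num_heads_qk run; a quick standalone check
    # (illustrative, not part of the test):
    #
    #     from itertools import product
    #     [(qk, v) for qk, v in product([2, 4, 8, 16], [2, 4, 8]) if v % qk == 0]
    #     # -> [(2, 2), (2, 4), (2, 8), (4, 4), (4, 8), (8, 8)]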

    torch.random.manual_seed(0)
    device = "npu"

    # Packed projections: qkvz carries q and k (num_heads_qk * head_qk_dim
    # each) plus v and z (num_heads_v * head_v_dim each); ba carries b and a
    # per v-head.
    projected_states_qkvz = torch.randn(seq_len,
                                        2 * head_qk_dim * num_heads_qk +
                                        2 * head_v_dim * num_heads_v,
                                        dtype=dtype,
                                        device=device)

    projected_states_ba = torch.randn(seq_len,
                                      2 * num_heads_v,
                                      dtype=dtype,
                                      device=device)

    # Clone the inputs so any in-place writes by the fused kernel cannot
    # perturb the reference path.
    projected_states_qkvz_copy = projected_states_qkvz.clone()
    projected_states_ba_copy = projected_states_ba.clone()

    mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
        projected_states_qkvz_copy,
        projected_states_ba_copy,
        num_heads_qk,
        num_heads_v,
        head_qk_dim,
        head_v_dim,
    )

    # Build a bare Qwen3NextGatedDeltaNet (bypassing __init__) so its
    # fix_query_key_value_ordering can serve as the reference implementation.
    gdn = Qwen3NextGatedDeltaNet.__new__(Qwen3NextGatedDeltaNet)
    gdn.num_k_heads = num_heads_qk
    gdn.num_v_heads = num_heads_v
    gdn.head_k_dim = head_qk_dim
    gdn.head_v_dim = head_v_dim
    gdn.tp_size = 1

    query, key, value, z_ref, b_ref, a_ref = gdn.fix_query_key_value_ordering(
        mixed_qkvz=projected_states_qkvz, mixed_ba=projected_states_ba)
    query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
                            (query, key, value))
    mixed_qkv_ref = torch.cat((query, key, value), dim=-1)

    validate_cmp(mixed_qkv, mixed_qkv_ref, dtype)
    validate_cmp(z, z_ref, dtype)
    validate_cmp(b, b_ref, dtype)
    validate_cmp(a, a_ref, dtype)