[CI] mv ops to correct path (#5615)

### What this PR does / why we need it?
mv ops to the correct path:
`tests/e2e/nightly/single_node/ops/singlecard_ops/triton`
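
For reference, the relocated suites can be run straight from the new directory; a minimal sketch (the target path comes from this description, the invocation itself is an assumption):

    import pytest
    pytest.main(["-v", "tests/e2e/nightly/single_node/ops/singlecard_ops/triton"])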

Signed-off-by: wangli <wangli858794774@gmail.com>
Author: Li Wang
Date: 2026-01-05 23:17:07 +08:00
Committed by: GitHub
Parent: 129ba9fe1b
Commit: c5e2f48510
3 changed files with 0 additions and 0 deletions


@@ -0,0 +1,100 @@
import pytest
import torch
from einops import rearrange
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet

from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \
    fused_qkvzba_split_reshape_cat


def validate_cmp(y_cal, y_ref, dtype, device='npu'):
    y_cal = y_cal.to(device)
    y_ref = y_ref.to(device)
    if dtype in (torch.float16, torch.bfloat16):
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=5e-03,
                                   atol=5e-03,
                                   equal_nan=True)
    elif dtype == torch.float32:
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=1e-03,
                                   atol=1e-03,
                                   equal_nan=True)
    elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64,
                   torch.uint32, torch.bool):
        assert torch.equal(y_cal, y_ref)
    else:
        raise ValueError(f'Unsupported dtype for comparison: {dtype}')


@pytest.mark.parametrize("seq_len", [1, 16, 64, 128, 256, 1024, 2048, 3567])
@pytest.mark.parametrize("num_heads_qk", [2, 4, 8, 16])
@pytest.mark.parametrize("num_heads_v", [2, 4, 8])
@pytest.mark.parametrize("head_qk_dim", [64, 128, 256])
@pytest.mark.parametrize("head_v_dim", [64, 128])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
def test_fused_qkvzba_split_reshape_cat(
seq_len,
num_heads_qk,
num_heads_v,
head_qk_dim,
head_v_dim,
dtype,
):
if num_heads_v % num_heads_qk != 0:
pytest.skip("num_heads_v must be divisible by num_heads_qk")
torch.random.manual_seed(0)
device = "npu"
projected_states_qkvz = torch.randn(seq_len,
2 * head_qk_dim * num_heads_qk +
2 * head_v_dim * num_heads_v,
dtype=dtype,
device=device)
projected_states_ba = torch.randn(seq_len,
2 * num_heads_v,
dtype=dtype,
device=device)
    # Clone the inputs in case the fused kernel mutates them in place.
    projected_states_qkvz_copy = projected_states_qkvz.clone()
    projected_states_ba_copy = projected_states_ba.clone()
mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
projected_states_qkvz_copy,
projected_states_ba_copy,
num_heads_qk,
num_heads_v,
head_qk_dim,
head_v_dim,
)

    # Build a bare Qwen3NextGatedDeltaNet (bypassing __init__) to reuse its
    # reference split/reordering logic.
    gdn = Qwen3NextGatedDeltaNet.__new__(Qwen3NextGatedDeltaNet)
    gdn.num_k_heads = num_heads_qk
    gdn.num_v_heads = num_heads_v
    gdn.head_k_dim = head_qk_dim
    gdn.head_v_dim = head_v_dim
    gdn.tp_size = 1
query, key, value, z_ref, b_ref, a_ref = gdn.fix_query_key_value_ordering(
mixed_qkvz=projected_states_qkvz, mixed_ba=projected_states_ba)
query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
(query, key, value))
mixed_qkv_ref = torch.cat((query, key, value), dim=-1)
validate_cmp(mixed_qkv, mixed_qkv_ref, dtype)
validate_cmp(z, z_ref, dtype)
validate_cmp(b, b_ref, dtype)
validate_cmp(a, a_ref, dtype)
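

# For context: a minimal pure-PyTorch sketch of the grouped split that the
# reference path performs (an assumption based on Qwen3-Next's per-key-head
# [q | k | v-group | z-group] layout; illustrative only, not the fused
# kernel's API).
def _reference_split_qkvz(mixed_qkvz, num_k_heads, num_v_heads, head_k_dim,
                          head_v_dim):
    group = num_v_heads // num_k_heads
    # Each key-head group packs q, k, then `group` v heads and `group` z heads.
    x = mixed_qkvz.view(-1, num_k_heads,
                        2 * head_k_dim + 2 * group * head_v_dim)
    return torch.split(
        x, [head_k_dim, head_k_dim, group * head_v_dim, group * head_v_dim],
        dim=-1)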


@@ -0,0 +1,95 @@
import pytest
import torch
from vllm.v1.sample.rejection_sampler import \
    rejection_random_sample_kernel as original_rejection_random_sample_kernel

from vllm_ascend.ops.triton.reject_sample import (
    cal_grid_and_block_size, rejection_random_sample_kernel)
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton


@pytest.fixture(scope="function", autouse=True)
def setup_device_properties():
    # Initialize Triton device properties before each test runs on the NPU.
    init_device_properties_triton()
    yield


@pytest.mark.parametrize("max_spec_len", [1, 2, 3])
@pytest.mark.parametrize("vocab_size", [151_936])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 64, 128, 256, 512, 1024])
@torch.inference_mode()
def test_rejection_random_sample(max_spec_len, vocab_size, batch_size):
device = 'npu'
torch.manual_seed(0)
draft_probs = torch.rand(batch_size * max_spec_len,
vocab_size,
dtype=torch.float32,
device=device)
target_probs = torch.rand(batch_size * max_spec_len,
vocab_size,
dtype=torch.float32,
device=device)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64,
device=device)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size * max_spec_len, ),
dtype=torch.int64,
device=device)
output_token_ids = torch.empty((batch_size, max_spec_len + 1),
dtype=torch.int64,
device=device)
original_output_token_ids = output_token_ids.clone()
num_tokens = draft_token_ids.shape[0]
uniform_probs = torch.rand((num_tokens, ),
dtype=torch.float32,
device=device)
num_draft_tokens = [max_spec_len] * batch_size
num_draft_tokens = torch.tensor(num_draft_tokens,
dtype=torch.int32,
device=device)
cu_num_draft_tokens = torch.cumsum(num_draft_tokens,
dim=0,
dtype=torch.int32)
is_greedy_ptr = torch.full((batch_size, ),
False,
dtype=torch.bool,
device=device)
recovered_ids = torch.zeros_like(draft_token_ids,
dtype=torch.int64,
device=device)
    grid, block_size = cal_grid_and_block_size(batch_size)
    # Reference result from the upstream vLLM kernel (one program per request).
    original_rejection_random_sample_kernel[(batch_size, )](
original_output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_ids,
uniform_probs,
is_greedy_ptr,
max_spec_len,
vocab_size,
NO_DRAFT_PROBS=draft_probs is None,
)
    # Ascend Triton kernel under test.
    rejection_random_sample_kernel[(grid, )](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        recovered_ids,
        uniform_probs,
        is_greedy_ptr,
        max_spec_len,
        vocab_size,
        batch_size,
        NO_DRAFT_PROBS=draft_probs is None,
        BLOCK_SIZE=block_size,
    )
    torch.npu.synchronize()
    # The Ascend kernel must reproduce the upstream vLLM kernel exactly.
    assert torch.equal(original_output_token_ids, output_token_ids)
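

# For context: the standard rejection-sampling accept rule both kernels apply
# per drafted token, u <= min(1, p_target / p_draft). A sketch with local
# names only (not the kernels' API):
def _accept_mask(draft_p: torch.Tensor, target_p: torch.Tensor,
                 u: torch.Tensor) -> torch.Tensor:
    # Accept a drafted token when its uniform draw falls under the clipped ratio.
    return u <= torch.clamp(target_p / draft_p, max=1.0)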


@@ -0,0 +1,214 @@
import gc

import numpy as np
import pytest
import torch

import vllm_ascend.ops.register_custom_ops  # noqa
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton

NUM_TOKENS = [1, 4, 8, 16, 1024]
NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)]
HEAD_SIZES = [128]
EPS = [1e-6]
DTYPES = [torch.bfloat16]
SEEDS = [0]
DEVICES = ["npu:0"]
DEFAULT_ATOL = 5e-2
DEFAULT_RTOL = 5e-3


def custom_rope(q, k, sin, cos):
    """Reference rotate-half RoPE: rotate_half(x) * sin + x * cos, in float32."""
    rotary_dim = sin.shape[-1]
    sin = sin.to(torch.float32)
    cos = cos.to(torch.float32)
    x1 = q[..., :rotary_dim // 2]
    x2 = q[..., rotary_dim // 2:]
    res1 = torch.cat([-x2, x1], dim=-1) * sin + q * cos
    x1 = k[..., :rotary_dim // 2]
    x2 = k[..., rotary_dim // 2:]
    res2 = torch.cat([-x2, x1], dim=-1) * sin + k * cos
    return res1, res2


def rms_norm(
    input,
    norm_weight,
    eps,
    norm_bias=None,
):
    """Reference RMSNorm in float32: x / sqrt(mean(x^2) + eps) * w (+ b)."""
    input = input.to(torch.float32)
    norm_weight = norm_weight.to(torch.float32)
    reciprocal_std = 1 / torch.sqrt(
        torch.mean(input**2, dim=-1, keepdim=True) + eps)
    out = input * reciprocal_std * norm_weight
    if norm_bias is not None:
        out = out + norm_bias.to(torch.float32)
    return out
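

# Quick sanity check of the reference helper above (hypothetical shapes; not
# part of the test suite). With unit weight, each output row should have
# roughly unit RMS.
def _rms_norm_self_check():
    x = torch.randn(4, 128)
    y = rms_norm(x, torch.ones(128), 1e-6)
    assert torch.allclose(torch.sqrt((y**2).mean(dim=-1)),
                          torch.ones(4),
                          atol=1e-3)

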
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
head_size, eps, dtype, seed, device):
torch.manual_seed(seed)
torch.set_default_device(device)
init_device_properties_triton()
q_hidden_size = num_q_heads * head_size
kv_hidden_size = num_kv_heads * head_size
qkv = torch.randn(num_tokens,
q_hidden_size + kv_hidden_size * 2,
dtype=dtype,
device=device)
q_weight = torch.randn(head_size, dtype=dtype, device=device)
k_weight = torch.randn(head_size, dtype=dtype, device=device)
sin = torch.from_numpy(
np.random.uniform(0, 1,
[num_tokens, 1, 1, head_size])).to(dtype).npu()
cos = torch.from_numpy(
np.random.uniform(0, 1,
[num_tokens, 1, 1, head_size])).to(dtype).npu()
# fused kernel
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
q_weight=q_weight,
k_weight=k_weight,
q_hidden_size=q_hidden_size,
kv_hidden_size=kv_hidden_size,
head_dim=head_size,
eps=eps,
cos=cos,
sin=sin)
# split
_q, _k, v_gold = qkv.cpu().split(
[q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1)
# norm
_q = rms_norm(_q.reshape(-1, head_size), q_weight.cpu(), eps)
_k = rms_norm(_k.reshape(-1, head_size), k_weight.cpu(), eps)
_q = _q.reshape(num_tokens, 1, -1, head_size)
_k = _k.reshape(num_tokens, 1, -1, head_size)
# rope
q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu())
q_gold = q_gold.reshape(num_tokens, -1)
k_gold = k_gold.reshape(num_tokens, -1)
# Compare the results.
torch.testing.assert_close(q.to(torch.float32).cpu(),
q_gold,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(k.to(torch.float32).cpu(),
k_gold,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(v.to(torch.float32).cpu(),
v_gold.to(torch.float32),
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
    # Release NPU memory between heavily parametrized cases.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads,
num_kv_heads, head_size, eps, dtype,
seed, device):
torch.manual_seed(seed)
torch.set_default_device(device)
init_device_properties_triton()
q_hidden_size = num_q_heads * head_size
kv_hidden_size = num_kv_heads * head_size
qkv = torch.randn(num_tokens,
q_hidden_size + kv_hidden_size * 2,
dtype=dtype,
device=device)
q_weight = torch.randn(head_size, dtype=dtype, device=device)
k_weight = torch.randn(head_size, dtype=dtype, device=device)
q_bias = torch.randn(head_size, dtype=dtype, device=device)
k_bias = torch.randn(head_size, dtype=dtype, device=device)
sin = torch.from_numpy(
np.random.uniform(0, 1,
[num_tokens, 1, 1, head_size])).to(dtype).npu()
cos = torch.from_numpy(
np.random.uniform(0, 1,
[num_tokens, 1, 1, head_size])).to(dtype).npu()
# fused kernel
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
q_weight=q_weight,
k_weight=k_weight,
q_hidden_size=q_hidden_size,
kv_hidden_size=kv_hidden_size,
head_dim=head_size,
eps=eps,
q_bias=q_bias,
k_bias=k_bias,
cos=cos,
sin=sin)
# split
_q, _k, v_gold = qkv.cpu().split(
[q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1)
# norm
_q = rms_norm(_q.reshape(-1, head_size),
q_weight.cpu(),
eps,
norm_bias=q_bias.cpu())
_k = rms_norm(_k.reshape(-1, head_size),
k_weight.cpu(),
eps,
norm_bias=k_bias.cpu())
_q = _q.reshape(num_tokens, 1, -1, head_size)
_k = _k.reshape(num_tokens, 1, -1, head_size)
# rope
q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu())
q_gold = q_gold.reshape(num_tokens, -1)
k_gold = k_gold.reshape(num_tokens, -1)
# Compare the results.
torch.testing.assert_close(q.to(torch.float32).cpu(),
q_gold,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(k.to(torch.float32).cpu(),
k_gold,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(v.to(torch.float32).cpu(),
v_gold.to(torch.float32),
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
    # Release NPU memory between heavily parametrized cases.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()