[CI] mv ops to correct path (#5615)
### What this PR does / why we need it? Move the ops tests to the correct path: `tests/e2e/nightly/single_node/ops/singlecard_ops/triton`. Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
@@ -0,0 +1,100 @@
|
||||
import pytest
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
|
||||
|
||||
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \
|
||||
fused_qkvzba_split_reshape_cat
|
||||
|
||||
|
||||
def validate_cmp(y_cal, y_ref, dtype, device='npu'):
|
||||
y_cal = y_cal.to(device)
|
||||
y_ref = y_ref.to(device)
|
||||
if dtype == torch.float16:
|
||||
torch.testing.assert_close(y_ref,
|
||||
y_cal,
|
||||
rtol=5e-03,
|
||||
atol=5e-03,
|
||||
equal_nan=True)
|
||||
elif dtype == torch.bfloat16:
|
||||
torch.testing.assert_close(y_ref,
|
||||
y_cal,
|
||||
rtol=5e-03,
|
||||
atol=5e-03,
|
||||
equal_nan=True)
|
||||
elif dtype == torch.float32:
|
||||
torch.testing.assert_close(y_ref,
|
||||
y_cal,
|
||||
rtol=1e-03,
|
||||
atol=1e-03,
|
||||
equal_nan=True)
|
||||
elif dtype == torch.int32 or dtype == torch.int64 or dtype == torch.int16 or dtype == torch.int8 or dtype == torch.uint32:
|
||||
assert torch.equal(y_cal, y_ref)
|
||||
elif dtype == torch.bool:
|
||||
assert torch.equal(y_cal, y_ref)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Invalid parameter \"dtype\" is found : {}'.format(dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_len", [1, 16, 64, 128, 256, 1024, 2048, 3567])
@pytest.mark.parametrize("num_heads_qk", [2, 4, 8, 16])
@pytest.mark.parametrize("num_heads_v", [2, 4, 8])
@pytest.mark.parametrize("head_qk_dim", [64, 128, 256])
@pytest.mark.parametrize("head_v_dim", [64, 128])
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
def test_fused_qkvzba_split_reshape_cat(
    seq_len,
    num_heads_qk,
    num_heads_v,
    head_qk_dim,
    head_v_dim,
    dtype,
):
    """Compare the fused Triton qkvzba split/reshape/cat kernel against the
    reference path ``Qwen3NextGatedDeltaNet.fix_query_key_value_ordering``.

    Both paths receive identical random projected states; the fused kernel's
    four outputs (mixed_qkv, z, b, a) must match the reference within
    dtype-dependent tolerances (see ``validate_cmp``).
    """
    # Grouped-head layout: skip parameter combinations the model does not
    # support (the kernel assumes num_heads_v is a multiple of num_heads_qk).
    if num_heads_v % num_heads_qk != 0:
        pytest.skip("num_heads_v must be divisible by num_heads_qk")

    torch.random.manual_seed(0)
    device = "npu"

    # Packed qkvz projection output.  Presumably 2 * head_qk_dim *
    # num_heads_qk covers q and k, and 2 * head_v_dim * num_heads_v covers
    # v and z — layout is defined by the kernel; TODO confirm.
    projected_states_qkvz = torch.randn(seq_len,
                                        2 * head_qk_dim * num_heads_qk +
                                        2 * head_v_dim * num_heads_v,
                                        dtype=dtype,
                                        device=device)

    # Packed b/a projection output: two scalars per v head.
    projected_states_ba = torch.randn(seq_len,
                                      2 * num_heads_v,
                                      dtype=dtype,
                                      device=device)

    # Feed the kernel clones so the reference path below sees pristine inputs
    # even if the kernel mutates its arguments.
    projected_states_qkvz_copy = projected_states_qkvz.clone()
    projected_states_ba_copy = projected_states_ba.clone()

    # Kernel under test.
    mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
        projected_states_qkvz_copy,
        projected_states_ba_copy,
        num_heads_qk,
        num_heads_v,
        head_qk_dim,
        head_v_dim,
    )

    # Reference: a bare Qwen3NextGatedDeltaNet built via __new__ (bypassing
    # __init__) with only the attributes fix_query_key_value_ordering reads.
    gdn = Qwen3NextGatedDeltaNet.__new__(Qwen3NextGatedDeltaNet)
    gdn.num_k_heads = num_heads_qk
    gdn.num_v_heads = num_heads_v
    gdn.head_k_dim = head_qk_dim
    gdn.head_v_dim = head_v_dim
    gdn.tp_size = 1

    query, key, value, z_ref, b_ref, a_ref = gdn.fix_query_key_value_ordering(
        mixed_qkvz=projected_states_qkvz, mixed_ba=projected_states_ba)
    # Flatten the per-head dimension to match the fused kernel's
    # concatenated (token, hidden) layout.
    query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
                            (query, key, value))
    mixed_qkv_ref = torch.cat((query, key, value), dim=-1)

    validate_cmp(mixed_qkv, mixed_qkv_ref, dtype)
    validate_cmp(z, z_ref, dtype)
    validate_cmp(b, b_ref, dtype)
    validate_cmp(a, a_ref, dtype)
|
||||
@@ -0,0 +1,95 @@
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.v1.sample.rejection_sampler import \
|
||||
rejection_random_sample_kernel as original_rejection_random_sample_kernel
|
||||
|
||||
from vllm_ascend.ops.triton.reject_sample import (
|
||||
cal_grid_and_block_size, rejection_random_sample_kernel)
|
||||
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
def setup_device_properties():
    """Autouse fixture: initialize Triton device properties before each test
    in this module so the Ascend kernels can query them."""
    init_device_properties_triton()
    yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_spec_len", [1, 2, 3])
@pytest.mark.parametrize("vocab_size", [151_936])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 64, 128, 256, 512, 1024])
@torch.inference_mode()
def test_rejection_random_sample(max_spec_len, vocab_size, batch_size):
    """Run vLLM's upstream rejection_random_sample kernel and the Ascend
    Triton port on identical inputs and require bit-identical outputs.
    """
    device = 'npu'
    torch.manual_seed(0)
    # Random (unnormalized) "probabilities"; both kernels see the same data.
    draft_probs = torch.rand(batch_size * max_spec_len,
                             vocab_size,
                             dtype=torch.float32,
                             device=device)
    target_probs = torch.rand(batch_size * max_spec_len,
                              vocab_size,
                              dtype=torch.float32,
                              device=device)
    # One bonus token per request, appended when all drafts are accepted.
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device=device)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size * max_spec_len, ),
                                    dtype=torch.int64,
                                    device=device)
    # Output buffers: one row per request, max_spec_len drafts + 1 bonus.
    output_token_ids = torch.empty((batch_size, max_spec_len + 1),
                                   dtype=torch.int64,
                                   device=device)
    original_output_token_ids = output_token_ids.clone()
    num_tokens = draft_token_ids.shape[0]
    # Per-draft-token uniform samples driving the accept/reject decision.
    uniform_probs = torch.rand((num_tokens, ),
                               dtype=torch.float32,
                               device=device)
    # Every request speculates exactly max_spec_len draft tokens.
    num_draft_tokens = [max_spec_len] * batch_size
    num_draft_tokens = torch.tensor(num_draft_tokens,
                                    dtype=torch.int32,
                                    device=device)
    cu_num_draft_tokens = torch.cumsum(num_draft_tokens,
                                       dim=0,
                                       dtype=torch.int32)
    # All-False: exercise only the random-sampling (non-greedy) path.
    is_greedy_ptr = torch.full((batch_size, ),
                               False,
                               dtype=torch.bool,
                               device=device)
    recovered_ids = torch.zeros_like(draft_token_ids,
                                     dtype=torch.int64,
                                     device=device)
    grid, block_size = cal_grid_and_block_size(batch_size)
    # Reference: upstream kernel, launched with one program per request.
    original_rejection_random_sample_kernel[(batch_size, )](
        original_output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        recovered_ids,
        uniform_probs,
        is_greedy_ptr,
        max_spec_len,
        vocab_size,
        # NOTE(review): draft_probs is a tensor here, so this is always
        # False — the NO_DRAFT_PROBS=True path is never exercised.
        NO_DRAFT_PROBS=draft_probs is None,
    )
    # Kernel under test: Ascend port with its own grid/block sizing.
    rejection_random_sample_kernel[(grid, )](output_token_ids,
                                             cu_num_draft_tokens,
                                             draft_token_ids,
                                             draft_probs,
                                             target_probs,
                                             bonus_token_ids,
                                             recovered_ids,
                                             uniform_probs,
                                             is_greedy_ptr,
                                             max_spec_len,
                                             vocab_size,
                                             batch_size,
                                             NO_DRAFT_PROBS=draft_probs
                                             is None,
                                             BLOCK_SIZE=block_size)
    # Kernel launches are async; synchronize before comparing on host.
    torch.npu.synchronize()
    assert torch.equal(original_output_token_ids, output_token_ids)
|
||||
@@ -0,0 +1,214 @@
|
||||
import gc
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm_ascend.ops.register_custom_ops # noqa
|
||||
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||
|
||||
# Parameter sweep for the qkv_rmsnorm_rope tests below.
NUM_TOKENS = [1, 4, 8, 16, 1024]
# (num_query_heads, num_kv_heads) pairs — covers MQA (kv=1) and GQA shapes.
NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)]
HEAD_SIZES = [128]
EPS = [1e-6]
DTYPES = [torch.bfloat16]
SEEDS = [0]
DEVICES = [f"npu:{0}"]
# Tolerances for comparing the bf16 fused kernel against the fp32 reference.
DEFAULT_ATOL = 5e-2
DEFAULT_RTOL = 5e-3
|
||||
|
||||
|
||||
def custom_rope(q, k, sin, cos):
    """Reference rotate-half RoPE applied to query and key tensors.

    Computes ``rotate_half(x) * sin + x * cos`` for both ``q`` and ``k``,
    where rotate_half swaps the two halves of the last dimension and negates
    the (new) first half.  ``sin``/``cos`` are upcast to float32; the rotary
    dimension is taken from ``sin.shape[-1]``.

    Args:
        q: Query tensor; last dim must be at least ``sin.shape[-1]``.
        k: Key tensor with the same layout as ``q``.
        sin: Sine table, broadcastable against ``q``/``k``.
        cos: Cosine table, broadcastable against ``q``/``k``.

    Returns:
        Tuple ``(q_roped, k_roped)``.
    """
    rotary_dim = sin.shape[-1]
    sin = sin.to(torch.float32)
    cos = cos.to(torch.float32)

    def _apply(x):
        # rotate_half: [x1, x2] -> [-x2, x1] along the last dim.
        x1 = x[..., :rotary_dim // 2]
        x2 = x[..., rotary_dim // 2:]
        rotated = torch.cat([-x2, x1], dim=-1)
        return rotated * sin + x * cos

    # The original duplicated this computation verbatim for q and k;
    # a shared helper keeps the two paths identical by construction.
    return _apply(q), _apply(k)
|
||||
|
||||
|
||||
def rms_norm(
    input,
    norm_weight,
    eps,
    norm_bias=None,
):
    """Reference RMSNorm computed in float32.

    Normalizes ``input`` by the root-mean-square of its last dimension,
    scales by ``norm_weight``, and optionally adds ``norm_bias``.

    Args:
        input: Tensor to normalize; normalization is over the last dim.
        norm_weight: Per-element scale, broadcastable to the last dim.
        eps: Small constant added to the mean of squares for stability.
        norm_bias: Optional per-element bias added after scaling.

    Returns:
        The normalized tensor in float32.
    """
    input = input.to(torch.float32)
    norm_weight = norm_weight.to(torch.float32)
    # rsqrt of the mean of squares over the last dim (the RMS), with eps.
    # Uses torch's canonical dim/keepdim kwargs (the original used the
    # NumPy-style axis/keepdims aliases) and rsqrt instead of 1/sqrt.
    reciprocal_std = torch.rsqrt(
        torch.mean(input**2, dim=-1, keepdim=True) + eps)
    out = input * reciprocal_std * norm_weight
    if norm_bias is not None:
        norm_bias = norm_bias.to(torch.float32)
        out = out + norm_bias
    return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
                                head_size, eps, dtype, seed, device):
    """Compare the fused qkv_rmsnorm_rope custom op (no bias) against a CPU
    reference that splits qkv, RMS-normalizes q/k per head, applies RoPE to
    q/k, and passes v through unchanged."""
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()

    q_hidden_size = num_q_heads * head_size
    kv_hidden_size = num_kv_heads * head_size
    # Packed projection output: [q | k | v] along the last dimension.
    qkv = torch.randn(num_tokens,
                      q_hidden_size + kv_hidden_size * 2,
                      dtype=dtype,
                      device=device)
    # Per-head RMSNorm weights, shared across heads.
    q_weight = torch.randn(head_size, dtype=dtype, device=device)
    k_weight = torch.randn(head_size, dtype=dtype, device=device)
    # Rotary tables shaped (tokens, 1, 1, head_size) to broadcast over heads.
    sin = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    cos = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    # Fused kernel under test.
    q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
                                              q_weight=q_weight,
                                              k_weight=k_weight,
                                              q_hidden_size=q_hidden_size,
                                              kv_hidden_size=kv_hidden_size,
                                              head_dim=head_size,
                                              eps=eps,
                                              cos=cos,
                                              sin=sin)

    # Reference path on CPU — split the packed qkv; v needs no processing.
    _q, _k, v_gold = qkv.cpu().split(
        [q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1)
    # Per-head RMSNorm in float32 (see rms_norm above).
    _q = rms_norm(_q.reshape(-1, head_size), q_weight.cpu(), eps)
    _k = rms_norm(_k.reshape(-1, head_size), k_weight.cpu(), eps)
    _q = _q.reshape(num_tokens, 1, -1, head_size)
    _k = _k.reshape(num_tokens, 1, -1, head_size)

    # Rotate-half RoPE on q/k (see custom_rope above).
    q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu())
    q_gold = q_gold.reshape(num_tokens, -1)
    k_gold = k_gold.reshape(num_tokens, -1)

    # Compare the results (bf16 kernel vs fp32 reference, loose tolerances).
    torch.testing.assert_close(q.to(torch.float32).cpu(),
                               q_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(k.to(torch.float32).cpu(),
                               k_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(v.to(torch.float32).cpu(),
                               v_gold.to(torch.float32),
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    # Release NPU memory between large parameterizations.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads,
                                          num_kv_heads, head_size, eps, dtype,
                                          seed, device):
    """Same comparison as test_split_qkv_rmsnorm_rope, but exercising the
    optional q_bias/k_bias inputs of the fused op (bias is added after the
    RMSNorm scale, before RoPE)."""
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()

    q_hidden_size = num_q_heads * head_size
    kv_hidden_size = num_kv_heads * head_size
    # Packed projection output: [q | k | v] along the last dimension.
    qkv = torch.randn(num_tokens,
                      q_hidden_size + kv_hidden_size * 2,
                      dtype=dtype,
                      device=device)
    # Per-head RMSNorm weights and biases, shared across heads.
    q_weight = torch.randn(head_size, dtype=dtype, device=device)
    k_weight = torch.randn(head_size, dtype=dtype, device=device)
    q_bias = torch.randn(head_size, dtype=dtype, device=device)
    k_bias = torch.randn(head_size, dtype=dtype, device=device)
    # Rotary tables shaped (tokens, 1, 1, head_size) to broadcast over heads.
    sin = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    cos = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    # Fused kernel under test (bias variant).
    q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
                                              q_weight=q_weight,
                                              k_weight=k_weight,
                                              q_hidden_size=q_hidden_size,
                                              kv_hidden_size=kv_hidden_size,
                                              head_dim=head_size,
                                              eps=eps,
                                              q_bias=q_bias,
                                              k_bias=k_bias,
                                              cos=cos,
                                              sin=sin)

    # Reference path on CPU — split the packed qkv; v needs no processing.
    _q, _k, v_gold = qkv.cpu().split(
        [q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1)
    # Per-head RMSNorm with bias in float32 (see rms_norm above).
    _q = rms_norm(_q.reshape(-1, head_size),
                  q_weight.cpu(),
                  eps,
                  norm_bias=q_bias.cpu())
    _k = rms_norm(_k.reshape(-1, head_size),
                  k_weight.cpu(),
                  eps,
                  norm_bias=k_bias.cpu())
    _q = _q.reshape(num_tokens, 1, -1, head_size)
    _k = _k.reshape(num_tokens, 1, -1, head_size)

    # Rotate-half RoPE on q/k (see custom_rope above).
    q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu())
    q_gold = q_gold.reshape(num_tokens, -1)
    k_gold = k_gold.reshape(num_tokens, -1)

    # Compare the results (bf16 kernel vs fp32 reference, loose tolerances).
    torch.testing.assert_close(q.to(torch.float32).cpu(),
                               q_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(k.to(torch.float32).cpu(),
                               k_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(v.to(torch.float32).cpu(),
                               v_gold.to(torch.float32),
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    # Release NPU memory between large parameterizations.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
Reference in New Issue
Block a user