[Ops][Misc] Optimize split_qkv_rmsnorm_rope op (#6827)

### What this PR does / why we need it?

This PR optimizes the `split_qkv_rmsnorm_rope` operator by introducing a
new Triton kernel, `split_qkv_rmsnorm_rope_prefill_kernel`, for the
prefill stage (i.e., large batch sizes). The implementation now
dynamically selects between the existing decode kernel and the new
prefill kernel based on the batch size, which improves performance for
large batch scenarios.

Additionally, the RoPE implementation is updated to support partial
rotation dimensions (`rope_dim`), making the operator more flexible.

### Does this PR introduce _any_ user-facing change?

No. This is a performance optimization and is not expected to introduce
any user-facing changes.

### How was this patch tested?

CI should pass with existing tests. The new prefill path is triggered
when the batch size is larger than the number of available vector cores.
The partial RoPE feature can be tested by passing the `rope_dim`
argument.
- vLLM version: v0.15.0
- vLLM main:
83b47f67b1

---------

Signed-off-by: guzhiyong <guzhiyong5@h-partners.com>
Signed-off-by: frank <2547457096@qq.com>
Co-authored-by: guzhiyong <guzhiyong5@h-partners.com>
This commit is contained in:
frank
2026-03-06 09:30:31 +08:00
committed by GitHub
parent a60e179c7f
commit 18b52afe2b
4 changed files with 290 additions and 177 deletions

View File

@@ -11,6 +11,7 @@ MAX_POSITION_EMBEDDINGS = [262144]
NUM_TOKENS = [1, 4, 8, 16, 1024]
NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)]
HEAD_SIZES = [128]
ROPE_DIMS = [64, 128]
EPS = [1e-6]
DTYPES = [torch.bfloat16]
SEEDS = [0]
@@ -23,19 +24,26 @@ def custom_rope(q, k, sin, cos):
rotary_dim = sin.shape[-1]
sin = sin.to(torch.float32)
cos = cos.to(torch.float32)
x1 = q[..., :rotary_dim // 2]
x2 = q[..., rotary_dim // 2:]
cat_x = torch.cat([-x2, x1], axis=-1)
mul1 = cat_x * sin
mul2 = q * cos
res1 = mul1 + mul2
q_rot = q[..., :rotary_dim]
k_rot = k[..., :rotary_dim]
q_pass = q[..., rotary_dim:]
k_pass = k[..., rotary_dim:]
x1 = k[..., :rotary_dim // 2]
x2 = k[..., rotary_dim // 2:]
x1 = q_rot[..., :rotary_dim // 2]
x2 = q_rot[..., rotary_dim // 2:]
cat_x = torch.cat([-x2, x1], axis=-1)
mul1 = cat_x * sin
mul2 = k * cos
res2 = mul1 + mul2
mul2 = q_rot * cos
q_rot = mul1 + mul2
res1 = torch.cat([q_rot, q_pass], dim=-1)
x1 = k_rot[..., :rotary_dim // 2]
x2 = k_rot[..., rotary_dim // 2:]
cat_x = torch.cat([-x2, x1], axis=-1)
mul1 = cat_x * sin
mul2 = k_rot * cos
k_rot = mul1 + mul2
res2 = torch.cat([k_rot, k_pass], dim=-1)
return res1, res2
@@ -64,9 +72,10 @@ def rms_norm(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("rope_dim", ROPE_DIMS)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads,
head_size, eps, dtype, seed, device):
head_size, eps, dtype, seed, device, rope_dim):
torch.manual_seed(seed)
torch.set_default_device(device)
init_device_properties_triton()
@@ -81,7 +90,7 @@ def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads
k_weight = torch.randn(head_size, dtype=dtype, device=device)
cos_sin_cache = torch.from_numpy(
np.random.uniform(0, 1,
[max_position_embeddings, head_size])).to(dtype).npu()
[max_position_embeddings, rope_dim])).to(dtype).npu()
positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
# fused kernel
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
@@ -141,10 +150,11 @@ def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("rope_dim", ROPE_DIMS)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, num_q_heads,
num_kv_heads, head_size, eps, dtype,
seed, device):
seed, device, rope_dim):
torch.manual_seed(seed)
torch.set_default_device(device)
init_device_properties_triton()
@@ -161,7 +171,7 @@ def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, n
k_bias = torch.randn(head_size, dtype=dtype, device=device)
cos_sin_cache = torch.from_numpy(
np.random.uniform(0, 1,
[max_position_embeddings, head_size])).to(dtype).npu()
[max_position_embeddings, rope_dim])).to(dtype).npu()
positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
# fused kernel
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,